PyTorch 사용자 정의 연산(Custom Operation)

2026년 5월 10일 게재된 이 기술 블로그에서는 PyTorch에서 C++ 및 CUDA를 사용해 고성능 사용자 정의 연산(Custom Operation)을 구현하는 방법을 다룹니다. 구체적으로 'my_ops' 네임스페이스 하위의 'identity_conv_op' 연산을 예시로 들어, CPU에서는 복사를 진행하고 GPU(CUDA) 환경에서는 256개 스레드 레이아웃과 Half 및 BFloat16 정밀도를 지원하는 디스패칭 구현 방식을 코드를 통해 상세히 설명합니다.

AI 요약

본 글은 PyTorch 모델의 성능 최적화를 위해 C++ 및 CUDA를 사용하여 사용자 정의 연산(Custom Operation)을 설계하고 이를 모델 및 AOTInductor 컴파일 추론 프로그램에 등록하는 실무적인 방법을 다룹니다. 저자는 간단한 'ID 합성곱(identity convolution)'의 처리 방식을 예시로 들어 설명합니다. CPU 환경에서는 identity_conv_cpu_impl을 통해 단순 텐서 복제(clone())를 수행하고, CUDA 환경에서는 디바이스 데이터를 분석해 맞춤형 커널을 실행하는 identity_conv_cuda_impl 방식을 채택합니다. 이 두 가지 구현체는 TORCH_LIBRARY_IMPL 매크로를 사용하여 각각 CPU와 CUDA 환경에 동적으로 매핑되도록 등록됩니다. 이를 통해 개발자는 고성능 모델의 배포 환경에서 텐서의 디바이스 종류에 따라 올바른 구현체로 연산을 자동 분기시키는 효율적인 파이프라인을 구축할 수 있습니다.

핵심 인사이트

  • 디바이스 독립적 매핑: TORCH_LIBRARY_IMPL 매크로를 사용하여 my_ops 라이브러리 내에 identity_conv_op 연산을 등록하고, CPU와 CUDA 장치에 맞는 구현체를 동적으로 분기(Dispatch)시킵니다.
  • 정밀도 다중 지원: CUDA 구현 내에서 AT_DISPATCH_FLOATING_TYPES_AND2 매크로를 활용하여 기본 부동 소수점 타입뿐만 아니라 at::ScalarType::Halfat::ScalarType::BFloat16 타입까지 모두 처리 가능하도록 구조화했습니다.
  • 최적의 스레드 구성: CUDA 연산을 최적화하기 위해 고정 상수 constexpr int kThreads = 256를 정의하여 하나의 블록당 256개의 스레드를 할당하는 병렬 처리 아키텍처를 채택했습니다.

주요 디테일

  • CPU 예외 처리 및 구현: CPU용 연산 함수인 identity_conv_cpu_impl은 입력된 텐서가 CUDA 디바이스에 존재하지 않는지 TORCH_CHECK로 검증한 뒤 안전하게 복제(clone()) 텐서를 반환합니다.
  • CUDA 메타데이터 업로드: identity_conv_cuda_impl 함수는 호스트 측 디스패처로서, 입력 텐서의 모양(Shape)과 스트라이드(Strides) 메타데이터를 담은 디바이스 텐서인 shape_devstrides_dev를 직접 생성하여 GPU 커널로 전달합니다.
  • 출력 메모리 할당: CUDA 연산 속도를 보존하기 위해 입력 텐서와 동일한 메타데이터 구조를 갖춘 빈 텐서를 빠르게 생성하는 torch::empty_like(input) API를 활용합니다.
  • 커널 론칭 에러 검증: GPU 커널 비동기 실행 이후 발생할 수 있는 잠재적 오류를 캐치하기 위해 PyTorch 표준 매크로인 C10_CUDA_KERNEL_LAUNCH_CHECK()를 호출하여 안정성을 확보합니다.

향후 전망

  • PyTorch 2.x 생태계에서 핵심으로 떠오른 AOTInductor 컴파일러 등 컴파일 기반의 고성능 최적화 추론 프로그램에서 이러한 C++/CUDA 커스텀 연산 결합 구조가 널리 쓰일 것으로 보입니다.
  • 다양한 인공지능 모델 설계에서 정밀도(BFloat16, Half) 요구량이 다양해짐에 따라 커스텀 오퍼레이터의 유연한 타입 바인딩 기술의 적용이 확대될 전망입니다.
Share

이것도 읽어보세요

댓글

이 소식에 대한 의견을 자유롭게 남겨주세요.

댓글 (0)

불러오는 중...