AI 요약
이 글은 LLM 내부 구조 시리즈의 5번째 포스팅으로, GPU(Triton)에서 구현된 Flash Attention 알고리즘을 구글 TPU 및 JAX 환경으로 이식하는 과정을 다룹니다. 저자는 JAX를 단순히 '컴파일된 NumPy'로 생각하고 접근했으나, 실제로는 XLA 컴파일러의 HLO(High-Level Operations) 변환 과정과 TPU 특유의 Systolic Array 구조로 인해 완전히 다른 접근이 필요함을 깨닫습니다. 특히 Triton에서 사용하던 가변 포인터 산술 연산(tl.store) 대신 JAX의 lax.dynamic_update_slice와 lax.fori_loop를 사용해야 하는 제약 사항을 상세히 설명합니다. 결과적으로 JAX/XLA는 추적(tracing)을 통해 순수 계산 그래프를 생성하며, 이 과정에서 발생하는 연산 융합(Fusion)과 VLIW 명령어 생성 원리를 파헤칩니다. 또한, 컴파일러의 한계를 넘어서기 위한 Pallas 라이브러리의 필요성을 언급하며 TPU 프로그래밍 모델의 본질을 조명합니다.
핵심 인사이트
- 날짜 및 배경: 2026년 3월 6일에 공개된 이 글은 LLM 내부 구조를 다루는 시리즈의 5부로, 무료 Google Colab TPU 런타임을 활용해 실험을 진행했습니다.
- HLO 프리미티브: JAX 코드는
jax.jit을 통해 약 100여 개의 프리미티브(dot, reduce, broadcast 등)로 구성된 HLO(High-Level Operations) 그래프로 변환됩니다. - 컴파일러 아키텍처: XLA 컴파일러는 GPU를 위해 PTX 코드를 생성하는 반면, TPU를 위해서는 VLIW(Very Long Instruction Word) 명령어를 생성합니다.
- 루프 처리의 차이: Python의 일반적인
for루프는 추적 시점에 언롤링(unrolling)되어 그래프 크기를 폭증시키므로, 이를 방지하기 위해jax.lax.fori_loop를 반드시 사용해야 합니다.
주요 디테일
- 메모리 불변성: Triton은
tl.store를 통해 가변 포인터에 직접 쓰기가 가능하지만, JAX 배열은 불변(Immutable)이므로return (new_max, new_sum, new_acc)와 같이 상태를 반환하는 함수형 방식을 취합니다. - 연산 융합(Fusion): XLA는 연속된 요소별(elementwise) 연산을 단일 커널로 융합하여 중간 데이터가 HBM(High Bandwidth Memory)에 머무르는 것을 방지합니다.
- 데이터 흐름 제어: Triton에서는
program_id와 포인터 산술을 개발자가 직접 통제하지만, JAX/XLA 환경에서는 컴파일러가 데이터 이동과 하드웨어 매핑을 결정합니다. - Systolic Array: TPU의 핵심 연산 장치인 Systolic Array에서 데이터가 어떻게 흐르는지 에뮬레이터를 통해 분석하고, 이것이 Flash Attention 성능에 미치는 영향을 고찰했습니다.
- Pallas의 역할: 컴파일러가 생성한 코드보다 더 높은 성능을 내기 위해서는 TPU 하드웨어에 더 밀접하게 접근할 수 있는 Pallas 라이브러리 활용이 필수적입니다.
향후 전망
- TPU 프로그래밍의 대중화: Google Colab의 무료 TPU 지원으로 인해 CUDA/Triton에 집중되었던 커스텀 커널 개발이 JAX/Pallas 생태계로 확장될 것으로 보입니다.
- 컴파일러 최적화 진화: XLA와 같은 고수준 컴파일러가 수동으로 최적화된 하드웨어 커널 성능에 점차 근접하며 추상화 계층의 효율성이 높아질 전망입니다.
출처:hackernews
