TorchTPU는 Google이 개발한 PyTorch-to-TPU 통합 스택이다. 기존 PyTorch 코드를 최소한의 수정으로 TPU에서 실행할 수 있게 해주며, 사용성·이식성·성능을 동시에 추구한다. Gemini, Veo 같은 Google 내부 모델 학습에 사용된 TPU 인프라를 PyTorch 개발자에게 개방한다. 2026년 4월 구글 개발자 블로그를 통해 공개됐으며, 공개 GitHub 저장소와 튜토리얼은 2026년 중 출시 예정이다.
설계 원칙: “느낌이 PyTorch여야 한다”
기존 PyTorch 스크립트에서 초기화만 "tpu"로 바꿔 실행할 수 있는 것을 목표로 한다. PrivateUse1 인터페이스를 통해 서브클래스나 래퍼 없이 일반 PyTorch 텐서가 TPU에서 동작한다.
Eager First: 세 가지 실행 모드
| 모드 | 특징 | 용도 |
|---|---|---|
| Debug Eager | 연산마다 동기화, 느리지만 정확 | NaN·OOM·형상 불일치 디버깅 |
| Strict Eager | 비동기 실행, PyTorch와 동일한 경험 | 기본 개발·테스트 |
| Fused Eager | 연산 스트림을 자동 융합해 TensorCore 활용 극대화 | 프로덕션 추론·학습 |
Fused Eager는 별도 설정 없이 Strict Eager 대비 50~100% 이상 성능 향상을 제공한다. TorchTPU가 실행 중인 연산 스트림을 분석해 자동으로 융합하며, 캐시된 컴파일 결과를 재사용해 반복 실행 비용을 줄인다.
정적 컴파일: torch.compile + XLA
최대 성능이 필요할 때는 torch.compile을 통해 전체 그래프를 XLA로 컴파일한다.
import torch
import torch_tpu # TorchTPU 통합
# 기존 코드 최소 수정
device = torch.device("tpu")
model = MyModel().to(device)
# torch.compile로 XLA 최적화
compiled_model = torch.compile(model, backend="torchtpu")컴파일 경로: Torch Dynamo → FX 그래프 캡처 → XLA 백엔드 → StableHLO IR → TPU 최적 바이너리
커스텀 커널이 필요한 경우 Pallas(JAX 기반)로 저수준 TPU 명령을 직접 제어할 수 있다.
import jax
import torch_tpu
@torch_tpu.pallas.custom_jax_kernel
def my_kernel(x, y):
# Pallas로 작성한 커스텀 TPU 커널
return jax.lax.dot(x, y)분산 학습
DDP, FSDPv2, DTensor를 기본 지원한다. 기존 PyTorch 분산 API 기반 서드파티 라이브러리 대부분이 변경 없이 동작한다.
이전 PyTorch/XLA의 한계였던 MPMD(rank 0에서의 추가 로깅 등 분기된 실행) 지원을 강화해, SPMD에서 벗어나는 코드도 자동으로 올바르게 처리한다.
TPU 하드웨어 특성 고려
- 현재 TPU는 어텐션 헤드 차원 128 또는 256에서 행렬 연산 효율이 최대
- 많은 모델이 하드코딩한 64 차원보다 128 이상이 권장
- TorchTPU 권장 워크플로: 정확한 실행 확인 (Debug/Strict Eager) → 성능 최적화 (Fused Eager/compile)
기존 PyTorch/XLA와 차이점
| 항목 | PyTorch/XLA (이전) | TorchTPU (신규) |
|---|---|---|
| 실행 방식 | SPMD 전용 | SPMD + MPMD 지원 |
| 텐서 타입 | 커스텀 XLATensor | 일반 PyTorch Tensor |
| Eager 모드 | 제한적 | 3단계 모드 (Debug/Strict/Fused) |
| 분기 코드 | 수동 처리 필요 | 자동 격리 처리 |
2026년 로드맵
- 동적 시퀀스 길이·배치 크기 재컴파일 최소화 (bounded dynamism)
- 표준 연산 프리컴파일 커널 라이브러리
- vLLM, TorchTitan과 깊은 통합
- Helion DSL 지원 확장
- torch.compile에서 동적 형상 네이티브 지원
- 공개 GitHub 저장소 + 재현 가능한 튜토리얼 출시
누가 쓰면 좋은가
| 사용자 | 사용 사례 |
|---|---|
| ML 연구자 | 기존 PyTorch 코드베이스를 TPU로 마이그레이션 |
| Google Cloud 사용자 | TPU Pod 규모의 대규모 모델 학습·서빙 |
| 추론 최적화 엔지니어 | Fused Eager + XLA로 서빙 지연 최소화 |
관련 문서
- litert-lm — Google의 온디바이스 LLM 런타임 LiteRT-LM
참고 자료
- TorchTPU: Running PyTorch Natively on TPUs at Google Scale — Google Developers Blog (2026-04-07)