AI Sparkup

최신 AI 쉽게 깊게 따라잡기⚡

TorchTPU – Google TPU에서 PyTorch를 네이티브로 실행하는 엔지니어링 스택

TorchTPU는 Google이 개발한 PyTorch-to-TPU 통합 스택이다. 기존 PyTorch 코드를 최소한의 수정으로 TPU에서 실행할 수 있게 해주며, 사용성·이식성·성능을 동시에 추구한다. Gemini, Veo 같은 Google 내부 모델 학습에 사용된 TPU 인프라를 PyTorch 개발자에게 개방한다. 2026년 4월 구글 개발자 블로그를 통해 공개됐으며, 공개 GitHub 저장소와 튜토리얼은 2026년 중 출시 예정이다.

설계 원칙: “느낌이 PyTorch여야 한다”

기존 PyTorch 스크립트에서 초기화만 "tpu"로 바꿔 실행할 수 있는 것을 목표로 한다. PrivateUse1 인터페이스를 통해 서브클래스나 래퍼 없이 일반 PyTorch 텐서가 TPU에서 동작한다.

Eager First: 세 가지 실행 모드

모드특징용도
Debug Eager연산마다 동기화, 느리지만 정확NaN·OOM·형상 불일치 디버깅
Strict Eager비동기 실행, PyTorch와 동일한 경험기본 개발·테스트
Fused Eager연산 스트림을 자동 융합해 TensorCore 활용 극대화프로덕션 추론·학습

Fused Eager는 별도 설정 없이 Strict Eager 대비 50~100% 이상 성능 향상을 제공한다. TorchTPU가 실행 중인 연산 스트림을 분석해 자동으로 융합하며, 캐시된 컴파일 결과를 재사용해 반복 실행 비용을 줄인다.

정적 컴파일: torch.compile + XLA

최대 성능이 필요할 때는 torch.compile을 통해 전체 그래프를 XLA로 컴파일한다.

import torch
import torch_tpu  # TorchTPU 통합

# 기존 코드 최소 수정
device = torch.device("tpu")
model = MyModel().to(device)

# torch.compile로 XLA 최적화
compiled_model = torch.compile(model, backend="torchtpu")

컴파일 경로: Torch Dynamo → FX 그래프 캡처 → XLA 백엔드StableHLO IR → TPU 최적 바이너리

커스텀 커널이 필요한 경우 Pallas(JAX 기반)로 저수준 TPU 명령을 직접 제어할 수 있다.

import jax
import torch_tpu

@torch_tpu.pallas.custom_jax_kernel
def my_kernel(x, y):
    # Pallas로 작성한 커스텀 TPU 커널
    return jax.lax.dot(x, y)

분산 학습

DDP, FSDPv2, DTensor를 기본 지원한다. 기존 PyTorch 분산 API 기반 서드파티 라이브러리 대부분이 변경 없이 동작한다.

이전 PyTorch/XLA의 한계였던 MPMD(rank 0에서의 추가 로깅 등 분기된 실행) 지원을 강화해, SPMD에서 벗어나는 코드도 자동으로 올바르게 처리한다.

TPU 하드웨어 특성 고려

  • 현재 TPU는 어텐션 헤드 차원 128 또는 256에서 행렬 연산 효율이 최대
  • 많은 모델이 하드코딩한 64 차원보다 128 이상이 권장
  • TorchTPU 권장 워크플로: 정확한 실행 확인 (Debug/Strict Eager) → 성능 최적화 (Fused Eager/compile)

기존 PyTorch/XLA와 차이점

항목PyTorch/XLA (이전)TorchTPU (신규)
실행 방식SPMD 전용SPMD + MPMD 지원
텐서 타입커스텀 XLATensor일반 PyTorch Tensor
Eager 모드제한적3단계 모드 (Debug/Strict/Fused)
분기 코드수동 처리 필요자동 격리 처리

2026년 로드맵

  • 동적 시퀀스 길이·배치 크기 재컴파일 최소화 (bounded dynamism)
  • 표준 연산 프리컴파일 커널 라이브러리
  • vLLM, TorchTitan과 깊은 통합
  • Helion DSL 지원 확장
  • torch.compile에서 동적 형상 네이티브 지원
  • 공개 GitHub 저장소 + 재현 가능한 튜토리얼 출시

누가 쓰면 좋은가

사용자사용 사례
ML 연구자기존 PyTorch 코드베이스를 TPU로 마이그레이션
Google Cloud 사용자TPU Pod 규모의 대규모 모델 학습·서빙
추론 최적화 엔지니어Fused Eager + XLA로 서빙 지연 최소화

관련 문서

  • litert-lm — Google의 온디바이스 LLM 런타임 LiteRT-LM

참고 자료



AI Sparkup 구독하기

최신 게시물 요약과 더 심층적인 정보를 이메일로 받아 보세요! (무료)