AI Sparkup

최신 AI 쉽게 깊게 따라잡기⚡

DiffusionGemma 튜토리얼 – vLLM 서빙과 JAX 파인튜닝 실전 개발자 가이드

diffusion-gemma는 256토큰 블록을 병렬로 생성해 기존 자기회귀 모델 대비 최대 4배 빠른 텍스트 생성을 달성한다. 이 문서는 실제 서빙과 파인튜닝에 초점을 맞춘 개발자 가이드다.

모델 사양

  • 아키텍처: Gemma 4 26B MoE 기반
  • 활성 파라미터: 3.8B (추론 시 희소 활성화)
  • VRAM 요구: 양자화 후 18GB 이내에서 배포 가능
  • 생성 속도: NVIDIA GeForce RTX 5090에서 700+ tokens/s, NVIDIA H100에서 1000+ tokens/s

핵심 아키텍처 원리

256토큰 Canvas 병렬 생성

자기회귀(AR) 모델은 토큰을 하나씩 생성하면서 매 스텝마다 모델 가중치를 메모리에서 로드한다. 병목은 메모리 대역폭이다.

DiffusionGemma는 256개 토큰의 플레이스홀더로 시작해 이 전체 canvas를 병렬로 반복 정제(denoising)한다. 병목이 메모리 대역폭에서 컴퓨트로 이동하면서 텐서코어를 최대 활용할 수 있다.

Uniform State Diffusion

랜덤 플레이스홀더 토큰으로 canvas를 초기화하고, 여러 번의 denoising 패스를 거쳐 고신뢰 토큰이 인접 위치 확정을 돕는다.

자기수정도 가능하다: 특정 토큰의 신뢰도가 낮아지면 샘플러가 해당 토큰을 다시 무작위화해 재시도한다. AR 모델에는 없는 기능이다.

Block Autoregressive Diffusion (256토큰 초과 시)

256토큰을 넘는 시퀀스는 블록 단위로 처리한다:

  1. 첫 256토큰 canvas를 완전히 denoising
  2. 완성된 블록을 KV 캐시에 커밋
  3. 다음 256토큰 canvas 초기화 (이전 블록을 컨텍스트로 참조)
  4. 반복

병렬 속도와 자기회귀 안정성을 결합한 방식이다.

Prefill과 Denoising의 이중 어텐션

단계어텐션 종류역할
PrefillCausal (인과적)프롬프트 컨텍스트를 KV 캐시에 기록
DenoisingBidirectional (양방향)Canvas 전체를 동시에 평가하며 정제

양방향 어텐션 덕분에 canvas의 모든 위치가 다른 모든 위치를 참조할 수 있다. Sudoku처럼 전역 제약이 있는 문제에 특히 유리하다.

vLLM으로 서빙하기

Google이 vLLM 팀과 협력해 DiffusionGemma를 vLLM에 통합했다.

vllm serve google/diffusiongemma-26B-A4B-it \
  --max-model-len 262144 \
  --max-num-seqs 4 \
  --gpu-memory-utilization 0.85 \
  --attention-backend TRITON_ATTN \
  --generation-config vllm \
  --hf-overrides '{"diffusion_sampler": "entropy_bound", "diffusion_entropy_bound": 0.1}' \
  --diffusion-config '{"canvas_length": 256}' \
  --enable-chunked-prefill

주요 플래그:

  • --attention-backend TRITON_ATTN — bidirectional attention 지원을 위한 Triton 백엔드
  • --diffusion-sampler entropy_bound — 신뢰도 기반 조기 종료로 레이턴시 감소
  • --canvas_length 256 — 기본 블록 크기 (변경 가능)

JAX로 파인튜닝하기: Sudoku Solver 예제

Google이 공개한 파인튜닝 레시피와 Hackable Diffusion JAX 툴킷을 활용한다.

왜 Sudoku인가

81칸 격자 문제는 모든 자리가 가로·세로·3×3 박스 제약으로 얽혀 있다. AR 모델은 왼쪽에서 오른쪽으로만 생성하므로 미래 제약을 고려하지 못하고 백트래킹도 불가능하다. DiffusionGemma의 양방향 어텐션과 자기수정 능력이 이 문제에서 명확한 우위를 보인다.

파인튜닝 결과

모델정답률평균 denoising 스텝
베이스 DiffusionGemma~0%48 스텝
SFT 파인튜닝 후80%12 스텝 (조기 종료)

SFT 어댑터가 모델을 더 빨리 수렴하게 만들어, 조기 종료 효과로 레이턴시까지 줄였다.

파인튜닝 코드 위치

# Gemma 공식 저장소의 diffusion 디렉터리
https://github.com/google-deepmind/gemma/tree/main/gemma/diffusion

기존 서빙 프레임워크와의 통합

Gemma 4 26B A4B 모델과 동일한 아키텍처를 공유하기 때문에 denoising 스텝만 추가로 구현하면 기존 서빙 프레임워크에 통합할 수 있다.

누가 쓰면 좋을까

  • GPU 서버 운영자: 메모리 대역폭이 아닌 컴퓨트가 풍부한 H100/RTX 5090 환경에서 텍스트 생성 throughput을 극대화하고 싶은 경우
  • 제약 기반 생성 연구자: 양방향 어텐션과 자기수정 능력이 필요한 태스크 (코드 완성, 템플릿 채우기, 퍼즐)
  • 파인튜닝 실험자: JAX 기반 diffusion 모델 커스터마이징을 탐색하고 싶은 경우

참고 자료



AI Sparkup 구독하기

최신 게시물 요약과 더 심층적인 정보를 이메일로 받아 보세요! (무료)