gemma 4의 추론 속도를 획기적으로 높이고 싶은 ML 엔지니어와 연구자를 위한 튜토리얼이다. Gemma 4는 메인 모델(타겟 모델)과 함께 경량 드래프터(drafter) 모델을 새로 공개했다. 이 드래프터가 여러 토큰을 미리 예측하면 타겟 모델이 한 번의 포워드 패스로 이를 검증하는 방식이 바로 Multi-Token Prediction(MTP), 즉 추측 디코딩(speculative decoding)이다. HuggingFace Transformers만으로 코드 몇 줄에 이 기능을 적용하는 방법을 단계별로 살펴본다.
사전 준비
- Python 3.10 이상 권장
- GPU 환경 (CPU에서도 동작하지만 속도 향상 효과는 GPU에서 극대화됨)
- HuggingFace 계정 및 Gemma 모델 접근 권한
pip install torch accelerate transformers- Gemma 모델 사용을 위해 HuggingFace 모델 허브에서 라이선스 동의 필요
MTP(Multi-Token Prediction)란?
전통적인 자동회귀(autoregressive) 디코딩은 타겟 모델이 토큰을 하나씩 순차적으로 생성한다. MTP는 이 과정을 두 단계로 분리한다.
- 드래프터(drafter)가 여러 토큰을 자동회귀 방식으로 빠르게 예측한다.
- 타겟 모델(target model)이 드래프터의 예측 토큰들을 한 번의 포워드 패스로 병렬 검증한다.
확률이 높은 토큰은 수락되고, 낮은 토큰은 거부된다. 거부된 이후의 토큰은 모두 무시된다. 타겟 모델은 항상 최소 1개의 토큰을 자체 생성하므로 출력 품질은 타겟 모델 수준을 그대로 유지한다. Gemma 4 드래프터는 타겟 모델의 활성화(activations)와 KV 캐시를 활용해 예측 정확도를 높였다.
1단계: 모델 로드
타겟 모델과 드래프터 모델을 각각 로드한다. 드래프터 모델 ID는 타겟 모델 ID에 -assistant를 붙이면 된다.
import torch
from transformers import AutoProcessor, AutoModelForCausalLM
TARGET_MODEL_ID = "google/gemma-4-E2B-it"
# 선택 가능 모델: "google/gemma-4-E2B-it", "google/gemma-4-E4B-it",
# "google/gemma-4-31B-it", "google/gemma-4-26B-A4B-it"
ASSISTANT_MODEL_ID = TARGET_MODEL_ID + "-assistant"
# 타겟 모델 (메인 Gemma 4 모델)
processor = AutoProcessor.from_pretrained(TARGET_MODEL_ID)
target_model = AutoModelForCausalLM.from_pretrained(
TARGET_MODEL_ID,
dtype=torch.bfloat16,
device_map="auto",
)
# 어시스턴트 모델 (경량 4-레이어 MTP 드래프터)
assistant_model = AutoModelForCausalLM.from_pretrained(
ASSISTANT_MODEL_ID,
dtype=torch.bfloat16,
device_map="auto",
)Gemma 4 드래프터는 4개 레이어로 구성된 경량 모델로, 타겟 모델에 비해 훨씬 빠르게 토큰을 생성한다.
2단계: MTP로 텍스트 생성
MTP 활성화는 단 한 줄 추가로 완성된다. model.generate 호출 시 assistant_model 파라미터에 드래프터 모델을 전달하면 된다.
messages = [
{
"role": "user",
"content": "Explain the concepts of speculative decoding and MTP in 3 sentences."
}
]
# 입력 준비
input_text = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
inputs = processor(text=input_text, return_tensors="pt").to(target_model.device)
# assistant_model 파라미터 하나로 MTP 활성화
outputs = target_model.generate(
**inputs,
assistant_model=assistant_model, # 이 한 줄이 전부다
max_new_tokens=256,
do_sample=False,
)
# 응답 디코딩
response = processor.decode(
outputs[0][inputs["input_ids"].shape[1]:],
skip_special_tokens=True
)
print(response)내부 동작 흐름은 다음과 같다.
- 드래프터가 N개 토큰을 자동회귀 방식으로 예측
- 타겟 모델이 N개 토큰을 1번의 포워드 패스로 병렬 검증
- 확률이 높은 토큰 수락, 낮은 토큰 거부
- 거부된 토큰 이후는 모두 무효 처리
- 타겟 모델은 항상 최소 1개 토큰을 자체 생성
3단계: 드래프트 토큰 수 조정
드래프터가 한 번에 생성하는 토큰 수는 속도와 정확도 사이의 트레이드오프에 영향을 준다.
| 드래프트 토큰 수 | 특성 |
|---|---|
| 많을 때 (예: 15개) | 수락률 하락 가능성, 수락률이 높을 때 속도 극대화 |
| 적을 때 (예: 3개) | 수락률 높음, 드래프터 속도 이점 감소 |
매번 최적값을 직접 탐색하는 대신, num_assistant_tokens_schedule을 "heuristic"으로 설정하면 런타임에 자동으로 조정된다.
- 모든 토큰 수락 시: 드래프트 수 +2 (드래프터가 정확하므로 더 많이 예측)
- 토큰 거부 발생 시: 드래프트 수 -1 (낭비 방지)
# 초기 드래프트 토큰 수 설정
assistant_model.generation_config.num_assistant_tokens = 4
# 동적 스케줄 활성화 ("heuristic") 또는 고정 ("constant")
assistant_model.generation_config.num_assistant_tokens_schedule = "heuristic"
# MTP로 생성
outputs = target_model.generate(
**inputs,
assistant_model=assistant_model,
max_new_tokens=256,
do_sample=False,
)
response = processor.decode(
outputs[0][inputs["input_ids"].shape[1]:],
skip_special_tokens=True
)
print(response)"heuristic" 스케줄은 프롬프트 특성에 따라 최적의 드래프트 토큰 수를 자동으로 찾아가므로, 별도의 튜닝 없이 대부분의 시나리오에서 좋은 성능을 낸다.
마치며
HuggingFace Transformers에서 Gemma 4 MTP를 활성화하는 방법은 매우 단순하다. generate 호출 시 assistant_model 파라미터 하나만 추가하면 된다. 품질은 타겟 모델 그대로 유지하면서 추론 속도를 크게 높일 수 있어, 지연(latency) 최소화가 중요한 온디바이스(on-device) 애플리케이션이나 실시간 서비스에 특히 유용하다.
num_assistant_tokens_schedule = "heuristic" 설정으로 드래프트 토큰 수를 자동 조정하면 사용 케이스별로 별도 튜닝 없이도 최적에 가까운 성능을 얻을 수 있다. gemma 4 모델 라인업과 각 모델의 특성은 Gemma 토픽 페이지에서 확인할 수 있다.
참고 자료
- Gemma 4 Multi-Token Prediction (MTP) using Hugging Face Transformers — Google AI for Developers (2026-05-10)