AI 推理性能调优:Speculative Decoding 投机解码的工程实践
AI 推理性能调优Speculative Decoding 投机解码的工程实践一、自回归解码的延迟困境逐 Token 生成的速度天花板大语言模型的推理过程是自回归的——每次生成一个 Token都要将前面所有 Token 重新送入模型计算一次。这种串行生成方式导致解码阶段的延迟与输出长度线性相关生成 512 个 Token 需要执行 512 次前向传播每次前向传播的延迟约 20-50 毫秒7B 模型单 GPU总延迟高达 10-25 秒。Speculative Decoding投机解码通过猜-验证模式打破了串行瓶颈用一个轻量级的草稿模型Draft Model快速生成多个候选 Token再用目标模型Target Model一次性验证这些候选 Token。如果候选正确相当于一次前向传播生成了多个 Token如果候选错误只需丢弃错误位置之后的 Token。实测中Speculative Decoding 可以将推理速度提升 2-3 倍且不损失输出质量。flowchart LR subgraph 传统自回归解码 T1[Token1] -- T2[Token2] -- T3[Token3] -- T4[Token4] -- T5[Token5] Note1[5次前向传播br/延迟: 5×30ms150ms] -.- T1 end subgraph 投机解码 D[草稿模型br/快速生成5个Token] -- V[目标模型br/一次验证5个Token] V --|3个正确| Accept[接受Token1-3] V --|第4个错误| Reject[丢弃Token4-5] D2[草稿模型br/从Token3重新生成] -- V2[目标模型验证] Note2[2次前向传播生成3个Tokenbr/延迟: 2×30ms60ms] -.- V end二、投机解码的核心机制2.1 草稿-验证流程投机解码分为三个阶段草稿阶段Draft、验证阶段Verify和接受阶段Accept。草稿模型以自回归方式快速生成 K 个候选 Token目标模型对这些候选 Token 执行一次前向传播同时得到每个位置的概率分布。通过比较目标模型和草稿模型的概率分布决定接受或拒绝每个候选 Token。2.2 拒绝采样与概率修正验证的关键在于拒绝采样Rejection Sampling如果目标模型在某个位置的概率高于草稿模型则接受该候选 Token否则以一定概率拒绝并从目标模型的概率分布中重新采样一个 Token。这种机制保证了投机解码的输出分布与原始自回归解码完全一致——不会降低输出质量。sequenceDiagram participant Draft as 草稿模型(7B) participant Target as 目标模型(70B) participant Buffer as 输出缓冲区 Note over Draft: 草稿阶段快速生成5个候选Token Draft-Draft: t1我 → t2认为 → t3这 → t4个 → t5方案 Note over Target: 验证阶段一次前向传播验证5个Token Draft-Target: 提交 [t1,t2,t3,t4,t5] Target-Target: 前向传播得到每个位置的概率 Target-Target: t1: P_target P_draft → 接受 ✅ Target-Target: t2: P_target P_draft → 接受 ✅ Target-Target: t3: P_target P_draft → 接受 ✅ Target-Target: t4: P_target P_draft → 拒绝 ❌ Target-Buffer: 输出 [t1, t2, t3] 从P_target采样t4 Note over Buffer: 一次前向传播生成4个Tokenbr/加速比: 4x三、生产级代码实现3.1 投机解码引擎import torch import torch.nn.functional as F from typing import List, Optional, Tuple import logging logger logging.getLogger(__name__) class SpeculativeDecoder: 投机解码引擎 设计考量 - 草稿模型与目标模型共享 Tokenizer避免编码转换开销 - 验证阶段使用批量前向传播一次验证所有候选 Token - 动态调整草稿长度 K草稿准确率高时增大 K低时减小 K - 温度参数传递草稿模型和目标模型使用相同的采样温度 def __init__( self, draft_model, target_model, tokenizer, draft_length: int 5, max_draft_length: int 8, min_draft_length: int 2, device: str cuda, ): self.draft_model draft_model self.target_model target_model self.tokenizer tokenizer self.draft_length draft_length self.max_draft_length max_draft_length self.min_draft_length min_draft_length self.device device # 统计指标 self._total_tokens 0 self._total_steps 0 self._accepted_tokens 0 torch.no_grad() def generate( self, prompt_ids: List[int], max_new_tokens: int 512, temperature: float 0.0, top_p: float 1.0, ) - List[int]: 使用投机解码生成文本 input_ids torch.tensor([prompt_ids], deviceself.device) generated_ids list(prompt_ids) while len(generated_ids) - len(prompt_ids) max_new_tokens: # Step 1: 草稿模型快速生成 K 个候选 Token draft_tokens, draft_probs self._draft_phase(input_ids, temperature) if not draft_tokens: # 草稿模型无法生成退回标准自回归 next_token self._target_autoregressive_step(input_ids, temperature, top_p) generated_ids.append(next_token) input_ids torch.tensor([generated_ids], deviceself.device) self._total_steps 1 self._total_tokens 1 continue # Step 2: 目标模型验证候选 Token accepted_count, new_token self._verify_phase( input_ids, draft_tokens, draft_probs, temperature, top_p ) # Step 3: 接受正确的候选 Token accepted_tokens draft_tokens[:accepted_count] generated_ids.extend(accepted_tokens) generated_ids.append(new_token) input_ids torch.tensor([generated_ids], deviceself.device) # 更新统计 self._total_steps 1 self._total_tokens accepted_count 1 self._accepted_tokens accepted_count # 动态调整草稿长度 self._adjust_draft_length(accepted_count, len(draft_tokens)) return generated_ids def _draft_phase( self, input_ids: torch.Tensor, temperature: float, ) - Tuple[List[int], List[torch.Tensor]]: 草稿阶段快速生成 K 个候选 Token draft_tokens [] draft_probs [] current_ids input_ids.clone() for _ in range(self.draft_length): outputs self.draft_model(current_ids) next_logits outputs.logits[:, -1, :] # 取最后一个位置 if temperature 0: probs F.softmax(next_logits / temperature, dim-1) else: probs F.softmax(next_logits, dim-1) # 贪心选择temperature0或采样 if temperature 0: next_token next_logits.argmax(dim-1).item() else: next_token torch.multinomial(probs, num_samples1).item() draft_tokens.append(next_token) draft_probs.append(probs.squeeze(0)) # 将新 Token 追加到输入继续生成下一个 current_ids torch.cat([ current_ids, torch.tensor([[next_token]], deviceself.device), ], dim-1) return draft_tokens, draft_probs def _verify_phase( self, input_ids: torch.Tensor, draft_tokens: List[int], draft_probs: List[torch.Tensor], temperature: float, top_p: float, ) - Tuple[int, int]: 验证阶段目标模型一次前向传播验证所有候选 Token Returns: accepted_count: 接受的候选 Token 数量 new_token: 第一个被拒绝位置的重采样 Token或全部接受时的下一个 Token # 构建验证输入原始输入 所有候选 Token draft_tensor torch.tensor([draft_tokens], deviceself.device) verify_ids torch.cat([input_ids, draft_tensor], dim-1) # 目标模型一次前向传播 outputs self.target_model(verify_ids) target_logits outputs.logits # 逐个验证候选 Token start_pos input_ids.shape[-1] - 1 # 从输入的最后一个位置开始 for i, draft_token in enumerate(draft_tokens): pos start_pos i target_prob F.softmax( target_logits[0, pos] / max(temperature, 1e-8), dim-1 ) draft_prob draft_probs[i] # 拒绝采样比较目标模型和草稿模型的概率 p_target target_prob[draft_token].item() p_draft draft_prob[draft_token].item() # 接受条件目标模型概率 草稿模型概率 if p_draft 0 and p_target p_draft: continue # 接受 # 拒绝以概率 (p_target - p_draft) / p_draft 接受 # 简化实现直接比较 if p_draft 0: accept_prob min(1.0, p_target / p_draft) if torch.rand(1).item() accept_prob: continue # 接受 # 拒绝从目标模型分布中采样 rejected_pos i # 修正分布max(0, p_target - p_draft) 归一化 corrected_prob torch.clamp(target_prob - draft_prob, min0) corrected_prob corrected_prob / corrected_prob.sum() if temperature 0: new_token corrected_prob.argmax().item() else: new_token torch.multinomial(corrected_prob, num_samples1).item() return rejected_pos, new_token # 所有候选 Token 都被接受从目标模型采样下一个 Token last_pos start_pos len(draft_tokens) next_prob F.softmax( target_logits[0, last_pos] / max(temperature, 1e-8), dim-1 ) if temperature 0: new_token next_prob.argmax().item() else: new_token torch.multinomial(next_prob, num_samples1).item() return len(draft_tokens), new_token def _target_autoregressive_step( self, input_ids: torch.Tensor, temperature: float, top_p: float, ) - int: 标准自回归步骤草稿模型失败时的降级方案 outputs self.target_model(input_ids) logits outputs.logits[:, -1, :] if temperature 0: logits logits / temperature probs F.softmax(logits, dim-1) return probs.argmax(dim-1).item() def _adjust_draft_length(self, accepted: int, total: int) - None: 根据草稿准确率动态调整草稿长度 acceptance_rate accepted / max(total, 1) if acceptance_rate 0.8 and self.draft_length self.max_draft_length: self.draft_length 1 elif acceptance_rate 0.4 and self.draft_length self.min_draft_length: self.draft_length - 1 def get_stats(self) - dict: 获取投机解码的统计指标 return { total_tokens: self._total_tokens, total_steps: self._total_steps, accepted_tokens: self._accepted_tokens, avg_tokens_per_step: round( self._total_tokens / max(self._total_steps, 1), 2 ), acceptance_rate: round( self._accepted_tokens / max(self._total_tokens, 1), 4 ), current_draft_length: self.draft_length, }四、边界分析与架构权衡4.1 草稿模型的准确率瓶颈投机解码的加速比直接取决于草稿模型的准确率。如果草稿模型的候选 Token 只有 50% 被接受平均每次验证只能生成 1-2 个 Token加速比仅 1.2-1.5x。选择草稿模型的关键是与目标模型同系列但更小如 Qwen2.5-0.5B 作为 Qwen2.5-7B 的草稿模型这样两者的分布更接近接受率更高。4.2 显存开销投机解码需要同时加载草稿模型和目标模型到 GPU。草稿模型通常较小0.5B-2B但仍然需要额外的显存。在显存紧张的场景下可以将草稿模型放在 CPU 上但 CPU 推理的延迟会抵消部分加速收益。4.3 批量推理的兼容性投机解码在单请求场景下效果最好。在批量推理中不同请求的草稿长度和接受位置不同难以高效地批量验证。目前的工程实践是单请求使用投机解码批量推理使用连续批处理Continuous Batching两者不混合使用。五、总结投机解码通过猜-验证模式在不损失输出质量的前提下将自回归解码的延迟降低 2-3 倍。其核心在于选择与目标模型分布接近的草稿模型以及动态调整草稿长度以匹配当前输入的预测难度。落地路线建议第一步选择与目标模型同系列的轻量级模型作为草稿模型第二步实现基本的草稿-验证流程测量接受率和加速比第三步添加动态草稿长度调整优化不同输入场景下的性能第四步集成到推理服务中仅对单请求场景启用投机解码。