speculative decoding学习笔记:
传统推测执行(如DeepMind的Chunkwise并行)仅预测单一路径,成功率仅50-60%。SpecInfer的突破在于:
1. 树状候选空间构建
2. 拓扑感知的并行验证
1. 分布式推理加速(2.8x)
系统 | 时延/Token | 加速比 |
---|---|---|
HuggingFace TGI | 58ms | 1.0x |
vLLM | 42ms | 1.4x |
SpecInfer | 15ms | 2.8x |
2. 卸载推理优化(3.5x)
3. 零精度损失的验证算法
def verify_stochastic(, ):
= ∅ # 初始化验证通过的令牌集合
u = root of token tree # 指向令牌树的根节点
while u is a non-leaf node:
ℋ = child(u) # ▶ u的子节点集合
while ℋ is not empty:
s ∼ rand(ℋ) # 随机选择一个子节点
r ∼ U(0, 1) # 生成随机数
xₛ = ℋ[s] # 获取令牌值
# 验证条件
if r ≤ P(xₛ | u, Θ_LLM) / P(xₛ | u, Θ_SSMₛ):
# ▶ 令牌 xₛ 通过验证
.append(xₛ)
u = s
break
else:
# ▶ 对残差概率进行标准化
P(x | u, Θ_LLM) := norm(max(0, P(x | u, Θ_LLM) - P(x | u, Θ_SSMs)))
ℋ.pop(s)
if ℋ is empty:
break
if ℋ is empty:
# ▶ 所有SSM验证失败; 采样下一个令牌
x_next ∼ P(x | u, Θ_LLM)
.append(x_next)
return
s ∼ rand(ℋ)
从子节点中随机选择验证目标P(xₛ | u, Θ_LLM)/P(xₛ | u, Θ_SSMₛ)
确定是否接受令牌P(x | u, Θ_LLM) := norm(max(0, P(x | u, Θ_LLM) - P(x | u, Θ_SSMs)))
数学证明(定理4.2):MSS严格等价于原始LLM的概率分布
1. 系统架构(图6)
以下通过一个客服机器人对话场景的具体示例,说明 SpecInfer 的 Continuous Batching 处理机制如何在实际系统中运作。该场景包含 3 个并发用户请求,系统使用 LLaMA-7B 作为大模型(LLM),LLaMA-160M 作为小模型(SSM)。
每个用户请求经 SSM 生成 候选 Token 树(宽度=3),树结构如下:
graph TD
subgraph Batch Tree
A[用户A:如何重置密码?] --> A1[点击]
A --> A2[进入]
A --> A3[选择]
B[用户B:订单未发货?] --> B1[联系]
B --> B2[查看]
B --> B3[申请]
C[用户C:会员续费?] --> C1[打开]
C --> C2[进入]
C --> C3[支付]
end
✅ 关键技术:SSM 为每个请求独立生成候选分支,树节点共享输入前缀(如“如何重置密码?”)的 KV 缓存。
LLM 一次性验证整棵 Batch Tree 的所有路径(共 3×3×3=27 条路径),通过 树注意力机制 并行计算:
# 输入拼接:所有候选路径合并为批处理张量
input_tokens = [
“如何重置密码?→点击”, “如何重置密码?→进入”, ... # 用户A的3条路径
“订单未发货?→联系”, “订单未发货?→查看”, ... # 用户B的3条路径
“会员续费?→打开”, “会员续费?→进入”, ... # 用户C的3条路径
]
# 单次LLM前向传播验证
output_probs = llm.forward(input_tokens)
# 验证结果(通过概率阈值筛选)
verified_tokens = []
for user_paths in output_probs:
valid_path = select_path(user_paths, threshold=0.8) # 保留概率>80%的路径
verified_tokens.append(valid_path[0]) # 取每个请求的首个通过Token
输出:
“点击”
(通过)“查看”
(通过)“打开”
(通过)⚡ 性能关键:单次前向传播完成所有用户当前步的解码,GPU 利用率达 92%。
根据验证结果更新批次,并动态加入新请求:
new_batch = []
for user, token in verified_tokens:
new_prompt = user.prompt + token # 追加已通过Token
if not is_finished(new_prompt): # 检查是否生成结束(如遇到句号)
new_batch.append({
"prompt": new_prompt,
"tree": ssm.generate_tree(new_prompt, width=3) # 生成下一轮树
})
# 新批次示例:
# 用户A: “如何重置密码?点击” → 新树: [“设置”, “页面”, “重新”]
# 用户B: “订单未发货?查看” → 新树: [“物流”, “状态”, “详情”]
# 用户C: “会员续费?打开” → 新树: [“账户”, “会员”, “中心”]
同时,新请求(用户D)加入下一轮批次:
new_batch.append({
"prompt": "如何退订服务?",
"tree": ssm.generate_tree("如何退订服务?", width=3)
})
动态性:每个迭代步长(iteration)结束后立即更新批次,新请求无需等待当前批完成。
指标 | 传统串行处理 | SpecInfer Continuous Batching | 提升效果 |
---|---|---|---|
端到端延迟 | 420ms | 150ms | 2.8倍 |
GPU利用率 | 28% | 92% | 3.3倍 |
吞吐量(tokens/s) | 62 | 240 | 3.9倍 |
数据来源:LLaMA-7B + 160M SSM 在 8×A100 上的测试结果。
迭代级调度(Iteration-level Scheduling)
树状KV缓存共享
残差概率标准化(Residual Normalization)
P_{\text{new}}(x) = \frac{\max(0, P_{\text{LLM}}(x) - P_{\text{SSM}}(x))}{\sum \max(0, P_{\text{LLM}} - P_{\text{SSM}})}
确保采样分布与原始 LLM 严格一致。2. 极致优化技巧
论文:SpecInfer: Accelerating Large Language Model Serving with Tree-based Speculative Inference and Verification
技术启示:当「推测执行」遇见「树形结构」,大模型推理的摩尔定律已被改写