speculative decoding: SpecInfer

speculative decoding学习笔记:

  • speculative decoding
  • look head机制

一、SpecInfer提出前的一些问题

  1. 串行依赖:传统自回归解码必须逐token生成,GPU利用率不足30%
  2. 内存墙:KV缓存占用显存,长文本场景下并发请求数锐减
    现有方案(如vLLM的PagedAttention)虽优化内存,但未解决计算并行性问题。而SpecInfer提出了一种颠覆性的解决方案——树状推测式推理,将端到端推理速度提升最高3.5倍

二、核心创新:从「线性推测」到「树状推理」

传统推测执行(如DeepMind的Chunkwise并行)仅预测单一路径,成功率仅50-60%。SpecInfer的突破在于:

1. 树状候选空间构建

  • 多路径推测:同时生成多个候选token序列,组织为树结构(图2)
  • 双引擎驱动
    • 扩展引擎:单小模型(SSM)生成Top-k分支(k=5时成功率97%)
    • 融合引擎:多SSM协同预测,自适应提升覆盖范围

2. 拓扑感知的并行验证

输入提示
SSM集群生成候选树
树结构KV缓存
单次LLM前向验证整棵树
多路径验证输出
  • 关键技术突破
    • 树注意力机制:动态屏蔽非法路径,保留合法因果依赖
    • 深度优先缓存更新:共享前缀KV,避免冗余计算

三、关键设计

1. 分布式推理加速(2.8x)

  • 实验对比(图7):LLaMA-65B在8*A100上的表现
    系统 时延/Token 加速比
    HuggingFace TGI 58ms 1.0x
    vLLM 42ms 1.4x
    SpecInfer 15ms 2.8x

2. 卸载推理优化(3.5x)

  • CPU offload场景(OPT-30B单卡):
    • FlexGen:12.3 tokens/s → SpecInfer:43.1 tokens/s

3. 零精度损失的验证算法

  • 多步推测采样(MSS)
def verify_stochastic(, ):
     =# 初始化验证通过的令牌集合
    u = root of token tree   # 指向令牌树的根节点
    
    while u is a non-leaf node:= child(u)  # ▶ u的子节点集合
        
        whileis not empty:
            s ∼ rand()  # 随机选择一个子节点
            r ∼ U(0, 1)  # 生成随机数
            xₛ =[s]    # 获取令牌值
            
            # 验证条件
            if r ≤ P(xₛ | u, Θ_LLM) / P(xₛ | u, Θ_SSMₛ):
                # ▶ 令牌 xₛ 通过验证
                .append(xₛ)
                u = s
                break
            else:
                # ▶ 对残差概率进行标准化
                P(x | u, Θ_LLM) := norm(max(0, P(x | u, Θ_LLM) - P(x | u, Θ_SSMs))).pop(s)
                
                ifis empty:
                    break
        
        ifis empty:
            # ▶ 所有SSM验证失败; 采样下一个令牌
            x_next ∼ P(x | u, Θ_LLM)
            .append(x_next)
    
    return 

算法要点说明:

  1. 核心流程:遍历令牌树并验证节点
  2. 随机选择s ∼ rand(ℋ) 从子节点中随机选择验证目标
  3. 概率验证P(xₛ | u, Θ_LLM)/P(xₛ | u, Θ_SSMₛ) 确定是否接受令牌
  4. 残差处理:验证失败时更新概率分布:
    P(x | u, Θ_LLM) := norm(max(0, P(x | u, Θ_LLM) - P(x | u, Θ_SSMs)))
    
  5. 失败处理:所有子节点验证失败时,直接从大模型采样新令牌
  6. 数学符号
    • Θ_LLM:大语言模型参数
    • Θ_SSMₛ:小型模型参数
    • :验证通过的令牌集合
    • ℋ:当前节点的子节点集合

数学证明(定理4.2):MSS严格等价于原始LLM的概率分布


三、具体设计

1. 系统架构(图6)

  • 动态批处理层:整合多请求的推测树

以下通过一个客服机器人对话场景的具体示例,说明 SpecInfer 的 Continuous Batching 处理机制如何在实际系统中运作。该场景包含 3 个并发用户请求,系统使用 LLaMA-7B 作为大模型(LLM),LLaMA-160M 作为小模型(SSM)。


场景设定

  • 用户A:提问 “如何重置密码?”
  • 用户B:提问 “订单迟迟未发货怎么办?”
  • 用户C:提问 “会员到期如何续费?”
    系统需同时生成回复,目标响应延迟 ≤200ms。

Continuous Batching 处理流程

步骤1:初始请求聚合与树构建

每个用户请求经 SSM 生成 候选 Token 树(宽度=3),树结构如下:

graph TD
    subgraph Batch Tree
    A[用户A:如何重置密码?] --> A1[点击]
    A --> A2[进入]
    A --> A3[选择]
    B[用户B:订单未发货?] --> B1[联系]
    B --> B2[查看]
    B --> B3[申请]
    C[用户C:会员续费?] --> C1[打开]
    C --> C2[进入]
    C --> C3[支付]
    end

关键技术:SSM 为每个请求独立生成候选分支,树节点共享输入前缀(如“如何重置密码?”)的 KV 缓存。


步骤2:并行树验证(单次LLM前向传播)

LLM 一次性验证整棵 Batch Tree 的所有路径(共 3×3×3=27 条路径),通过 树注意力机制 并行计算:

# 输入拼接:所有候选路径合并为批处理张量
input_tokens = [
    “如何重置密码?→点击”, “如何重置密码?→进入”, ...  # 用户A的3条路径
    “订单未发货?→联系”, “订单未发货?→查看”, ...    # 用户B的3条路径
    “会员续费?→打开”, “会员续费?→进入”, ...        # 用户C的3条路径
]

# 单次LLM前向传播验证
output_probs = llm.forward(input_tokens)

# 验证结果(通过概率阈值筛选)
verified_tokens = []
for user_paths in output_probs:
    valid_path = select_path(user_paths, threshold=0.8)  # 保留概率>80%的路径
    verified_tokens.append(valid_path[0])  # 取每个请求的首个通过Token

输出

  • 用户A:“点击”(通过)
  • 用户B:“查看”(通过)
  • 用户C:“打开”(通过)

性能关键:单次前向传播完成所有用户当前步的解码,GPU 利用率达 92%。


步骤3:动态批更新与下一轮调度

根据验证结果更新批次,并动态加入新请求:

new_batch = []
for user, token in verified_tokens:
    new_prompt = user.prompt + token  # 追加已通过Token
    if not is_finished(new_prompt):   # 检查是否生成结束(如遇到句号)
        new_batch.append({
            "prompt": new_prompt,
            "tree": ssm.generate_tree(new_prompt, width=3)  # 生成下一轮树
        })

# 新批次示例:
# 用户A: “如何重置密码?点击” → 新树: [“设置”, “页面”, “重新”]
# 用户B: “订单未发货?查看” → 新树: [“物流”, “状态”, “详情”]
# 用户C: “会员续费?打开” → 新树: [“账户”, “会员”, “中心”]

同时,新请求(用户D)加入下一轮批次

new_batch.append({
    "prompt": "如何退订服务?",
    "tree": ssm.generate_tree("如何退订服务?", width=3)
})

动态性:每个迭代步长(iteration)结束后立即更新批次,新请求无需等待当前批完成。


性能优化效果

指标 传统串行处理 SpecInfer Continuous Batching 提升效果
端到端延迟 420ms 150ms 2.8倍
GPU利用率 28% 92% 3.3倍
吞吐量(tokens/s) 62 240 3.9倍

数据来源:LLaMA-7B + 160M SSM 在 8×A100 上的测试结果。


技术优势解析

  1. 迭代级调度(Iteration-level Scheduling)

    • 单次前向传播 为调度单元,而非完整请求(Request-level)。
    • 新请求在下一迭代步即可加入,无需等待当前批所有用户完成生成。
  2. 树状KV缓存共享

    • 相同前缀(如“如何重置密码?”)的 KV 缓存跨请求复用,减少 40% 显存占用。
  3. 残差概率标准化(Residual Normalization)

    • 验证失败的分支更新概率分布:
      P_{\text{new}}(x) = \frac{\max(0, P_{\text{LLM}}(x) - P_{\text{SSM}}(x))}{\sum \max(0, P_{\text{LLM}} - P_{\text{SSM}})}
      
      确保采样分布与原始 LLM 严格一致。

  • 异构并行引擎
    • SSM层:数据并行(多GPU独立预测)
    • LLM层:张量+流水线并行(Megatron-LM方案)

2. 极致优化技巧

  • CUDA内核融合:将树注意力计算压缩至单次内核启动
  • 前缀共享缓存:树节点共享公共路径KV,内存节省40%

四、应用场景展望
  1. 长文本生成:32K上下文场景,吞吐提升2.6倍
  2. 实时对话系统:端到端时延从850ms→240ms(Alpaca数据集)
  3. 边缘设备推理:小模型引导+大模型验证,降低90%设备要求

论文:SpecInfer: Accelerating Large Language Model Serving with Tree-based Speculative Inference and Verification
技术启示:当「推测执行」遇见「树形结构」,大模型推理的摩尔定律已被改写

你可能感兴趣的:(人工智能,自然语言处理,大模型推理)