推测解码(投机解码)

概述

推测解码是一种优化推理过程的方法,适用于自回归模型。自回归模型可以简单理解为一个“猜字”游戏:基于当前已有的字面,预测下一个 token,然后将新 token 加入上下文,循环进行预测。在某些情况下,可以借助结构更小、速度更快的模型来优化这一过程。

推测解码的实现

推测解码通过训练一个小模型(草稿模型)来提前推测输出,再用原始的大模型进行验证。
如果大模型对某个 token 给出的概率不小于阈值 ( p ),则采纳草稿模型的结果;否则,回退到大模型重新进行推理。

效率提升的关键

虽然验证阶段同样需要经过大模型计算,但效率提升的核心在于并行处理
例如,草稿模型在一次解码中先生成未来 10 个 token,然后大模型可以并行地对这 10 个 token 进行验证,从而显著提高整体推理效率。

小模型与大模型的交互方式

自回归模型的约束在于概率依赖关系,而不是计算方式必须串行
在推测解码中,依赖关系由小模型给出,大模型只需在给定完整前缀的条件下,同时计算多个位置的条件概率。
对于 Transformer 来说,这种并行计算是天然支持的。

自回归模型的概率分解形式

自回归模型的概率分解为:

[
p(y_{1:T}) = \prod_{t=1}^{T} p(y_t \mid y_{<t})
]

该公式只要求第 ( t ) 个 token 的条件分布依赖于前缀 ( y_{<t} ),并没有规定必须“算完第 ( t-1 ) 个 token 才能算第 ( t ) 个 token 的 logits”。
只要前缀已知,多个位置的条件概率是可以同时计算的。

Transformer 模型的并行计算能力

在训练阶段,Transformer 本身就是并行计算所有位置的 logits。
给定完整序列 ( [y_1, y_2, \dots, y_T] ),通过 causal mask 保证第 ( t ) 个位置看不到未来信息,一次 forward 就能计算出所有位置的 logits。

关键点:前缀被提前“假设”出来

在推测解码中:

  1. 小模型先生成:
    [
    \hat{y}{t+1}, \hat{y}{t+2}, \dots, \hat{y}_{t+k}
    ]

  2. 对大模型而言,这些 token 已经是“已知前缀”,它面对的是一个完整序列。

因此,大模型可以像训练时一样:

通过这种方式,推测解码充分利用了 Transformer 的并行计算能力,从而显著提升推理效率。