推测解码(投机解码)
概述
推测解码是一种优化推理过程的方法,适用于自回归模型。自回归模型可以简单理解为一个“猜字”游戏:基于当前已有的字面,预测下一个 token,然后将新 token 加入上下文,循环进行预测。在某些情况下,可以借助结构更小、速度更快的模型来优化这一过程。
推测解码的实现
推测解码通过训练一个小模型(草稿模型)来提前推测输出,再用原始的大模型进行验证。
如果大模型对某个 token 给出的概率不小于阈值 ( p ),则采纳草稿模型的结果;否则,回退到大模型重新进行推理。
效率提升的关键
虽然验证阶段同样需要经过大模型计算,但效率提升的核心在于并行处理。
例如,草稿模型在一次解码中先生成未来 10 个 token,然后大模型可以并行地对这 10 个 token 进行验证,从而显著提高整体推理效率。
小模型与大模型的交互方式
- 小模型:每次解码生成一个或一段 token(串行)。
- 大模型:一次 forward 可以同时处理多个 token(并行)。
自回归模型的约束在于概率依赖关系,而不是计算方式必须串行。
在推测解码中,依赖关系由小模型给出,大模型只需在给定完整前缀的条件下,同时计算多个位置的条件概率。
对于 Transformer 来说,这种并行计算是天然支持的。
自回归模型的概率分解形式
自回归模型的概率分解为:
[
p(y_{1:T}) = \prod_{t=1}^{T} p(y_t \mid y_{<t})
]
该公式只要求第 ( t ) 个 token 的条件分布依赖于前缀 ( y_{<t} ),并没有规定必须“算完第 ( t-1 ) 个 token 才能算第 ( t ) 个 token 的 logits”。
只要前缀已知,多个位置的条件概率是可以同时计算的。
Transformer 模型的并行计算能力
在训练阶段,Transformer 本身就是并行计算所有位置的 logits。
给定完整序列 ( [y_1, y_2, \dots, y_T] ),通过 causal mask 保证第 ( t ) 个位置看不到未来信息,一次 forward 就能计算出所有位置的 logits。
关键点:前缀被提前“假设”出来
在推测解码中:
-
小模型先生成:
[
\hat{y}{t+1}, \hat{y}{t+2}, \dots, \hat{y}_{t+k}
] -
对大模型而言,这些 token 已经是“已知前缀”,它面对的是一个完整序列。
因此,大模型可以像训练时一样:
- 构造输入:
[
y_{\le t}, \hat{y}{t+1}, \dots, \hat{y}{t+k}
] - 使用 causal mask;
- 一次 forward 同时得到 ( k ) 个位置的 logits。
通过这种方式,推测解码充分利用了 Transformer 的并行计算能力,从而显著提升推理效率。