掘金 人工智能 06月23日 14:17
Speculative Decoding 推测解码方案详解
index_new5.html
../../../zaker_core/zaker_tpl_static/wap/tpl_guoji1.html

 

本文深入探讨了推测解码(Speculative Decoding)技术,旨在加速大语言模型(LLM)的推理过程。通过介绍EAGLE、Medusa、Lookahead等方案,文章揭示了如何利用低成本模型预测、并行验证等方法,减少大模型的解码次数,从而提高效率。同时,也讨论了草稿模型、Prompt Lookup Decoding、Jacobi Decoding、Lookahead Decoding和Medusa等不同技术的原理、优缺点及应用,为读者提供了全面的技术视角。

💡推测解码的核心在于提高解码的算术强度,它通过低成本的方式生成候选token,然后并行验证,减少大模型的解码次数。

💡草稿模型是推测解码的关键组成部分,例如EAGLE和DeepMind提出的方案,利用小模型预测token,再由大模型验证,但其性能受限于接受率,且难以获得“又小又准”的模型。

💡Prompt Lookup Decoding通过匹配输入prompt中的n-gram,预测输出token,适用于输出高度重复prompt内容的场景,无需额外模型,提升推理效率。

💡Jacobi Decoding将自回归解码转化为非线性方程组求解,并行预测多个token,但实际应用中面临token位置不准确的挑战。

💡Lookahead Decoding借鉴Jacobi Decoding,收集Jacobi生成路径中的n-gram,并缓存至n-gram池中,通过主模型验证,实现跳跃式推进生成过程。

💡Medusa通过在主模型上添加多个解码头,并行预测未来token,有效缓解了草稿模型常见的分布偏移问题,提升了候选token的准确性和稳定性。

Speculative Decoding 推测解码方案的讲解视频可以在这里观看:www.bilibili.com/video/BV1Q5…

本文是 LLM 推理系列的第 4 篇,介绍 Speculative Decoding 推测解码方案详解,详细介绍了 EAGLE、Medusa、Lookahead 等主流的 Speculative Decoding 方案。

往期文章:

当前,大语言模型(LLM)在推理阶段普遍采用自回归解码策略,其核心特性是逐步串行生成 token,每一步都依赖前一步的输出。这一计算模式导致推理过程在系统层面面临严重的内存带宽瓶颈:每一步前向计算都需要将完整的模型参数从高带宽内存(HBM)加载到加速器缓存,但仅生成一个 token。由于每次只生成一个 token,导致大量的计算资源被闲置,无法充分发挥加速器的算力潜力,最终造成整体推理效率低下。

为解决这一问题,一种加速大语言模型推理的思路是提高解码过程的算术强度(即总浮点运算次数 FLOPs 与数据传输量之间的比值),同时减少解码步骤。基于这一理念,研究者们提出了推测解码/投机解码(Speculative Decoding) 技术。Speculative Decoding 的核心思路如下图所示,首先以低成本的方式(一般来说是用小模型)快速生成多个候选 token,然后通过一次并行验证阶段快速验证多个 token,进而减少大模型的 decode 次数,从而达到加速的目的。

这种方法在实际中之所以有效,是因为大多数时候草稿模型生成的 token 都会被接受。这些 token 本身就容易预测,即使是一个小得多的草稿模型也能准确生成。当这些容易的 token 被接受时,模型就可以快速跳过这些部分。对于那些较难预测的 token,如果目标大模型不同意草稿模型的输出,就会退回到原始的解码速度,甚至略微变慢,因为还需要进行额外的验证计算。

这里举两个例子:

比如 prompt 是 The capital of South Korea is ?,那么输出大概率是 The capital of 开头,这些 token 都很常见,草稿模型很容易预测正确,因此可以被目标模型接受,快速跳过。

在代码提示的场景下,也很好预测,比如 prompt 是:

循环以下数组:nums = [1, 2, 3]

输出大概率是:

for num in nums:    xxx

1 在 vLLM 中使用 Speculative Decoding

在 vLLM 中可以通过 speculative_config 参数来使用 Speculative Decoding。以下代码通过离线模式使用草稿模型进行 Speculative Decoding,每次推测 5 个 token。

from vllm import LLM, SamplingParamsprompts = [    "The future of AI is",]sampling_params = SamplingParams(temperature=0.8, top_p=0.95)llm = LLM(    model="facebook/opt-6.7b",    tensor_parallel_size=1,    speculative_config={        "model": "facebook/opt-125m",        "num_speculative_tokens": 5,    },)outputs = llm.generate(prompts, sampling_params)for output in outputs:    prompt = output.prompt    generated_text = output.outputs[0].text    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

使用以下命令可以以在线模式运行 Speculative Decoding。

vllm serve \  facebook/opt-6.7b \  --speculative-config '{"model": "facebook/opt-125m", "num_speculative_tokens": 5}'

2 早期的 Speculative Decoding

论文:

Fast Inference from Transformers via Speculative Decoding:arxiv.org/abs/2211.17…

Accelerating Large Language Model Decoding with Speculative Sampling:arxiv.org/abs/2302.01…

Speculative Decoding 最早源自 Google 在 2023 年发表的一篇论文,标题为 "Fast Inference from Transformers via Speculative Decoding",该论文首次提出了通过“草稿模型 + 验证模型”的方式来加速 Transformer 模型的推理。

同一时期,DeepMind 也独立发布了另一篇相关论文 "Accelerating Large Language Model Decoding with Speculative Sampling",其背后的核心思想与 Google 的这篇类似。

下图展示了 Speculative Decoding 算法的执行流程。

为了更直观地理解 Speculative Decoding 算法的执行流程,我们以下面的场景为例,走一遍整个生成过程。

假设当前的上下文为:"This apple",词表如下:

vocab = ["This", "apple", "is", "very", "delicious", "bad", "today"]

本轮草稿长度 K=2,也就是说草稿模型 p 将尝试生成两个 token。

第一步:草稿模型生成 token

草稿模型 p 是一个轻量的自回归模型。每一步会计算 logits,并通过 softmax 得到概率分布,从中采样出下一个 token。草稿模型对每个生成 token 的概率(即 softmax 后的值)会被保留,用于之后目标模型的验分过程。

# 基于上下文 "This apple",预测出下一个 token:p_logits_1 = [1.5, 1.8, 2.5, 1.1, 0.3, 0.05, -1.0]  # 草稿模型 logits for "This apple"p_probs_1 = softmax(p_logits_1)                     # 草稿模型计算 softmax 概率# → 草稿模型生成:𝑥̃₁ = "is"# 接着以 "This apple is" 为上下文继续预测:p_logits_2 = [0.7, 0.9, 1.0, 1.3, 0.4, 0.1, -0.8]  # 草稿模型 logits for "This apple is"p_probs_2 = softmax(p_logits_2)                    # 草稿模型计算 softmax 概率# → 草稿模型生成:𝑥̃₂ = "very"

第二步:目标模型并行计算 logits

目标模型 q 是较大但更精确的模型,得益于草稿 token 已经生成完毕,可以并行地对草稿序列的每个位置进行打分(即计算 logits),从而加快验分过程。

q_logits_1 = [1.8, 2.0, 2.2, 1.2, 0.5, 0.1, -0.7]   # 目标模型 q(x | "This apple")q_logits_2 = [0.9, 1.1, 1.0, 1.1, 0.7, 0.2, -0.5]   # q(x | "This apple is")q_logits_3 = [0.5, 0.4, 1.2, 0.7, 2.5, -0.2, 0.0]   # q(x | "This apple is very")q_probs_1 = softmax(q_logits_1)q_probs_2 = softmax(q_logits_2)

第三步:目标模型对草稿 token 验分

对第一个草稿 token:𝑥̃₁ = "is",分别查找两个模型计算出的 softmax 概率:

# 提取 token "is" 对应的概率:p_prob = p_probs_1["is"]  # 草稿模型计算的概率q_prob = q_probs_1["is"]  # 目标模型计算的概率# 假设:p_prob = 0.30q_prob = 0.22

计算接受概率:

acceptprob1=min(1, qprobpprob)=min(1, 0.220.30)=0.733\text{acceptprob}_1 = \min\left(1,\ \frac{qprob}{pprob} \right) = \min(1,\ \frac{0.22}{0.30}) = 0.733

假设采样出的随机数 r1=0.42r_1 = 0.42,因为:

0.42<0.733接受"is"0.42 < 0.733 \quad \Rightarrow \quad 接受 "is"

对第二个草稿 token:𝑥̃₂ = "very"

p_prob = p_probs_2["very"] = 0.28  # 草稿模型计算的概率q_prob = q_probs_2["very"] = 0.24  # 目标模型计算的概率
acceptprob2=min(1, 0.240.28)=0.857\text{acceptprob}_2 = \min\left(1,\ \frac{0.24}{0.28} \right) = 0.857

随机数 r2=0.62r_2 = 0.62,因为:

0.62<0.857接受"very"0.62 < 0.857 \quad \Rightarrow \quad 接受 "very"

第四步:生成奖励 token(bonus token)

因为草稿中的所有 token 都通过了目标模型的验分,我们使用事先计算好的:

q(x | "This apple is very") → q_logits_3

对其进行 softmax 操作,得到概率分布。假设采用贪心解码策略,则选择概率最大的 token。

q_logits_3 = [0.5, 0.4, 1.2, 0.7, 2.5, -0.2, 0.0]# softmax 后 "delicious" 概率最大

因此生成奖励 token:

xₙ₊ₖ₊₁ = "delicious"

3 草稿模型的限制

虽然 Speculative Decoding 通过“猜测-验证”策略,在多个场景下显著提升了解码速度,但草稿模型仍然存在以下限制:

4 Prompt Lookup Decoding

在许多大语言模型(LLM)的应用场景中,例如摘要生成、文档问答、多轮对话和代码编辑,模型的输入(prompt)与输出之间往往存在大量的 n-gram(连续 token 片段)重合。这些重合内容通常是实体名称、常见短语或代码片段,模型在生成时常常直接从输入中复制它们。

Prompt Lookup Decoding 正是利用了这一规律,在生成过程中加速自回归解码。它通过高效的字符串匹配算法(如 KMP)快速定位匹配位置,直接利用 prompt 中已有的信息进行预测,而无需依赖额外的草稿模型生成候选 token。尤其在模型输出高度重复 prompt 内容的场景下,这种方法能显著提升推理效率。当大模型在答案中重复 prompt 的内容时,这种方法效果尤为显著。

vLLM 的 n-gram Speculative Decoding 不仅仅使用原始提示进行匹配,而是使用整个上下文序列(包括原始提示和已生成的标记)。每次生成新标记后,整个上下文序列都会更新,然后用于下一轮的 n-gram 匹配。

上图展示了一个 Prompt Lookup Decoding 的示例:给定一个 prompt,我们会提取其中所有的 2-gram 作为查找键(key),并将它们后面紧跟的三个 token 作为查找值(value)。在生成过程中,我们会检查当前生成的 2-gram 是否与查找表中的某个 key 匹配。如果匹配成功,就使用对应的 value 作为后续的候选 token 进行生成。

在下面这段代码中,我们使用了 n-gram Speculative Decoding,每次最多推测 3 个 token,最多使用 2 个 n-gram 进行匹配。

from vllm import LLM, SamplingParamsprompts = [    "What is the capital of South Korea?",]sampling_params = SamplingParams(temperature=0.8, top_p=0.95)llm = LLM(    model="facebook/opt-6.7b",    tensor_parallel_size=1,    speculative_config={        "method": "ngram", # 使用 n-gram Speculative Decoding        "num_speculative_tokens": 3, # 每次最多推测 3 个 token        "prompt_lookup_max": 2, # 最多使用 2 个 n-gram 进行匹配    },)outputs = llm.generate(prompts, sampling_params)for output in outputs:    prompt = output.prompt    generated_text = output.outputs[0].text    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

5 Jacobi Decoding

论文:Accelerating Transformer Inference for Translation via Parallel Decoding:arxiv.org/pdf/2305.10…

Jacobi 迭代法是一种经典的非线性方程组求解方法。在大语言模型(LLM)推理中,我们也可以将其应用于并行生成 token,且不依赖草稿模型。传统上,自回归解码被视为一个逐步生成 token 的串行过程,如下图左侧所示。而通过对一些公式进行简单重排,这一过程也可以被看作是在求解一个非线性方程组,如下图右侧所示。

通过 Jacobi 迭代法,可以并行地求解该非线性系统中的所有变量 [y1,y2,,ym][y_1, y_2, \dots, y_m],步骤如下:

Jacobi Decoding 之所以可以让大模型并行预测多个 token,是因为它将原本自回归的串行生成过程转化为了一个非线性系统的“并行迭代求解”问题。每一轮都用大模型,在不同位置上并行预测 token,并使用上一轮的结果作为输入,而不是依赖本轮刚生成的内容。

上图展示了这一并行解码过程(也称为 Jacobi Decoding)。Jacobi Decoding 能够在最多 mm 步内求解所有 mm 个变量(即,与自回归解码所需步数相同),因为每一步至少能够保证第一个 token 被正确解码。有时候,多个 token 可能会在一次迭代中同时收敛,从而减少整体解码步数。例如,Jacobi Decoding 在第 4 步中同时预测并接受了两个 token:“computer” 和 “scientist”。

与自回归解码相比,Jacobi Decoding 的每一步在计算上会稍微更昂贵一些,因为它需要在多个 token 上同时进行语言模型的前向计算。但幸运的是,由于 GPU 的并行处理特性,这种额外计算通常不会带来显著的延迟。

Jacobi Decoding 就像:你先把整篇文章草拟出来,然后让老师一段段检查、划掉不对的、保留正确的,然后继续在此基础上写后面的内容。

在实际应用中,Jacobi Decoding 在实现显著的实际加速方面面临不少挑战。尽管它在多轮迭代中确实能够一次生成多个 token,但这些 token 通常难以被准确地放置在正确的位置。即便部分 token 被正确预测,它们也常常在后续的迭代中被新的预测所替代。最终,真正能做到多个 token 同时生成且位置正确的情况非常有限,难以体现并行 Jacobi Decoding 原本追求的效率优势。

6 Lookahead Decoding

论文:Break the Sequential Dependency of LLM Inference Using Lookahead Decoding:arxiv.org/abs/2402.02…

Github:github.com/hao-ai-lab/…

Lookahead Decoding 的灵感来源于 Jacobi Decoding,它将自回归解码视为求解非线性系统的问题,并通过固定点迭代(fixed-point iteration)一次性并行生成多个未来 token。虽然 Jacobi Decoding 中初始化的 token 往往不准确,但其生成轨迹(Jacobi trajectory)中包含的 n-gram 片段可能在后续解码中成为有效的候选。

Lookahead Decoding 正是利用了这一特性:收集 Jacobi 生成路径中的 n-gram,并将其缓存至 n-gram 池中。随后系统在解码时从中挑选可能命中的 n-gram 并通过主模型验证是否可接受,从而跳跃式推进生成过程,显著降低推理延迟。

整个过程中,Lookahead Decoding 不依赖草稿模型,部署简洁,同时借助前瞻生成与并行验证两个分支,在单步内完成多个候选的生成与验证,最大化利用了原本未被自回归解码充分使用的计算资源。

7 Medusa

论文:Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads:arxiv.org/abs/2401.10…

Github:github.com/FasterDecod…

相比于引入一个独立的草稿模型,Medusa 提出了一种更紧凑且高效的设计思路:直接在原始主干模型上进行扩展。其核心做法是在主模型的最后一个隐藏层上添加多个轻量级解码头(Medusa Heads),每个解码头并行预测未来不同位置的 token。Medusa Heads 是在原始主干模型(backbone model)基础上一起训练得到的。

每个预测头的结构极其简洁,仅由一层构成,与主干模型的输出层相似,因此不会增加推理服务系统的复杂度。同时,预测头的输出分布与主模型高度一致,有效缓解了草稿模型常见的分布偏移问题,显著提升了候选 token 的准确性与稳定性。

训练时,主干模型可以保持冻结状态(称为 MEDUSA-1),也可以与预测头一起联合训练(称为 MEDUSA-2)。这种设计使得即使只使用单个 GPU,也能对大型模型进行高效微调,充分利用主干模型已学到的强大特征表示。

在推理阶段,每个预测头会针对其对应的位置生成多个 top 候选 token。然后,这些候选会被组合成多个候选序列,并通过树状注意力机制(tree attention) 并行处理。最后一步是采用典型接受策略(typical acceptance)来筛选出合理的生成路径,并将被接受的最长候选前缀作为下一轮解码的起点。通过同时接受更多的 token,可以减少所需的解码步数,从而提升整个解码过程的效率。

接下来,让我们深入了解 Medusa 的 3 个组成部分:Medusa heads、tree attention 和 typical acceptance。

7.1 Medusa heads

Medusa 之所以能够一次性解码多个 token,核心在于它在原始语言模型的最后隐藏层输出之上,附加了多个解码头(Medusa heads)。每个解码头负责预测序列中不同偏移位置的未来 token。这些解码头结构类似于原始模型的语言模型头,通常是单层前馈神经网络,可以独立地为其目标位置生成 token 预测。这种设计使得模型在一次前向传播中即可并行预测多个后续 token,而不再像传统自回归生成方式那样必须逐步生成单个 token,从而显著提升解码效率。

具体来说:

为了使 MEDUSA 的多个解码头具备良好的预测能力,需要对其进行训练。根据不同的应用场景和资源条件,可以选择不同的训练方式:

MEDUSA-1:冻结主干,仅训练解码头

在该方案中,原始模型的主干(包括原有的解码头)保持冻结,仅训练新增的 MEDUSA 解码头。适用于计算资源有限,或希望完全保留原模型性能的场景。

该方式还可以结合 QLoRA 对解码头进行参数高效微调,从而进一步降低内存和计算资源的消耗。

实验表明,MEDUSA-1 在 Vicuna-7B 上可实现 2.18× 的推理速度提升

MEDUSA-2:联合训练主干与解码头

该方案对原模型主干和 MEDUSA 解码头进行联合训练。相比于 MEDUSA-1,虽然资源开销更大,但也能更充分地发挥多个解码头在推理加速方面的潜力。

MEDUSA-2 通过联合训练,使 MEDUSA 解码头的分布与主干模型保持一致,有效缓解分布漂移问题,显著提升预测准确性

适用于计算资源充足,或从 base 模型进行 SFT(监督微调)的场景。实验证明,MEDUSA-2 在 Vicuna-7B 上可实现 2.83× 的速度提升,优于 MEDUSA-1。

为保证训练效果和稳定性,MEDUSA-2 引入了以下关键训练策略:

由于主干模型已经预训练充分、较为稳定,而 MEDUSA 解码头是新增模块、仍处于早期学习阶段,因此为两者设置不同的学习率是合理的选择:

这样既能提升训练效率,又能保护主干模型不被破坏。

在训练初期,MEDUSA heads 的损失较大,可能导致梯度过大,从而干扰主干模型的参数更新。为此,采用两阶段训练流程:

具体可采用两种方式:

训练数据的获取

7.2 Tree Attention

Medusa 在推理时,每个解码头不仅输出一个最可能的 token,而是采样多个 top-k 候选 token,并将它们组合成一棵树状结构的候选序列集合。作者在实验中发现,尽管预测“下下一个” token 的 top-1 准确率仅约为 60%,但 top-5 的准确率却超过了 80%。这一显著提升表明,如果能够合理利用这些高概率预测结果,便可以在每个解码步骤中生成更多 token,从而显著加速推理过程。为了在 不增加大模型调用次数 的前提下高效验证多个候选路径,Medusa 引入了 Tree Attention 机制。

以一个具体例子说明:

假设当前有两个 Medusa 解码头:

这就构成了 2 × 3 = 6 条可能路径:

路径 1: It → is  路径 2: It → '  路径 3: It → the  路径 4: I → is  路径 5: I → '  路径 6: I → the

我们将原始输入看作根节点,每条路径从根节点出发,经过 head1 的一个 token,再接上 head2 的一个 token,形成一个二层的候选树。

为了避免对每条路径逐个调用大模型进行验证,Medusa 使用 Tree Attention,将所有候选路径中的 token 扁平展开为一个连续序列:

["It", "I", "is", "'", "the", "is", "'", "the"]

接着,构建 Tree Mask,确保注意力机制只发生在逻辑上有因果关系的 token 之间:

所以,利用 Tree Attention 就实现了一次大模型调用就把所有的路径都验证了的效果。考虑到真实场景,大模型的推理瓶颈往往在内存读取,而不是计算。所以适当地增加计算量,基本不会影响推理耗时,反而通过 Tree Attention 有可能得到更长的 draft。

7.3 Typical acceptance

在传统的 Speculative Decoding 中,最常用的候选验证方法是 Rejection Sampling(拒绝采样)。它的做法是:草稿模型(通常是一个小模型)先生成多个候选 token,原始大模型逐个验证这些 token 是否“符合自己概率分布”,不符合就全部丢弃。这种方式虽然准确,能保持输出分布和原始模型一致,但也存在一个严重问题——效率低。尤其当采样温度(temperature)较高时,草稿模型会生成更多富有创意、但偏离主流分布的候选 token。此时原始模型很容易拒绝它们,导致解码步骤缩短、重复计算,效率大大下降。

为了提高整体推理效率,Medusa 提出了一种更宽容、更高效的验证策略 —— Typical Acceptance(典型接受)。该策略从 截断采样(Truncation Sampling) 的研究中汲取灵感,目标是扩大原始模型可接受的候选范围。具体而言,Medusa 不再要求草稿模型生成的 token 与原始模型的分布完全对齐,而是根据原始模型的预测概率设定一个接受阈值:只要某个候选 token 的概率超过该阈值,就可以接受它以及它之前的所有 token(即其 prefix)。在这些被接受的 token 中,Medusa 使用贪婪策略(Greedy)选取 top-k,作为最终的候选输出。

它的核心思想是:Medusa 不强求草稿模型生成的 token 与原始模型的分布完全一致,只要这些 token 不是极不可能的结果,就可以被接受。

假设当前上下文是:“The weather is”,草稿模型(温度较高)生成的下一个 token 是:

"fun"

而原始模型的概率分布如下:

Token概率
"nice"0.35
"bad"0.30
"cold"0.15
"fun"0.08
"wet"0.06
"?"0.03
"sad"0.02
"quick"0.01

我们设定一个累计概率阈值为 90%,从高到低累加:

Rejection Sampling 是怎么做的:

Typical Acceptance 是怎么做的:

7.4 Medusa 的缺点

每个 Medusa Head 是独立执行的,即“预测下下个 token”不会依赖于“下一个 token” 的预测结果,缺乏序列间的依赖性,导致生成效果不佳、草稿接受率较低。

8 EAGLE

论文:EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty:arxiv.org/pdf/2401.15…

EAGLE-2: Faster Inference of Language Models with Dynamic Draft Trees:arxiv.org/pdf/2406.16…

EAGLE-3: Scaling up Inference Acceleration of Large LanguageModels via Training-Time Test:arxiv.org/pdf/2503.01…

Github:github.com/SafeAILab/E…

EAGLE(Extrapolation Algorithm for Greater Language-model Efficiency)提出了一种新颖的推测采样框架,其核心创新在于:

EAGLE 经第三方评估认证,是目前最快的 Speculative Decoding 方法:

在 vLLM 中,可以通过如下方式启用 EAGLE。需要注意的是,EAGLE 的草稿模型需在不启用张量并行的模式下运行,而目标模型则可以正常使用张量并行以提升推理效率。

from vllm import LLM, SamplingParamsprompts = [    "The future of AI is",]sampling_params = SamplingParams(temperature=0.8, top_p=0.95)llm = LLM(    model="meta-llama/Meta-Llama-3-8B-Instruct",    tensor_parallel_size=4, # 张量并行    speculative_config={        "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",        "draft_tensor_parallel_size": 1, # 无张量并行    },)outputs = llm.generate(prompts, sampling_params)for output in outputs:    prompt = output.prompt    generated_text = output.outputs[0].text    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

8.1 核心思路

EAGLE 的作者提出了两个核心观点:

核心观点一:在特征层进行自回归比在 token 层进行自回归更简单

核心观点二:采样过程中的不确定性限制了特征预测的性能

8.2 实现细节

EAGLE 的整体框架如下:

草稿阶段(Drafting Phase):

验证阶段(Verification Phase):

下图展示了 EAGLE 如何通过特征层级的自回归模型高效生成草稿 token 序列。整个 Draft Model(草稿模型)主要由三部分组成:Embedding 层Autoregression HeadLM Head。其中,Embedding 和 LM Head 都复用目标 LLM 的参数,无需训练;唯一需要训练的是 Autoregression Head,由一个全连接层(FC)和一个 decoder 层组成。

整体流程如下:

    假设当前已生成的 token 为 "How", "can", "I"。从原始 LLM 中取出 token "How""can" 在最后一层的输出特征向量,分别记作 fhowf_{\text{how}}fcanf_{\text{can}},以及 token "can""I" 在输入端的 embedding 向量,分别为 ecane_{\text{can}}eIe_{\text{I}}。将 fhowf_{\text{how}}ecane_{\text{can}} 拼接,fcanf_{\text{can}}eIe_{\text{I}} 拼接,构成输入序列。

    拼接后的序列输入到 Autoregression Head,经过降维和解码器处理,预测出 token "I" 的特征向量 fIf_{\text{I}}。然后将 fIf_{\text{I}} 输入到目标 LLM 的 LM Head,得到下一个 token 的分布,并从中采样出 token,例如 "make""help"

    将采样得到的新 token(如 "make""help")送入下一轮 forward,提取其 embedding 向量(如 emakee_{\text{make}}ehelpe_{\text{help}}),与上一轮的预测特征 fIf_{\text{I}} 拼接,作为下一步的输入,继续预测下一步的特征(如 fmakef_{\text{make}}fhelpf_{\text{help}}),再由 LM Head 映射为 token(如 "a""our""with""you")。

    每轮 forward 并非只预测一个 token,而是基于当前所有路径并行扩展多个分支。例如:"make" 预测出 "a""our""help" 预测出 "with""you"。第三轮继续扩展为 "the""your""to""feel",形成一棵 token 草稿树。

补充一点,第一次前向传播无法加速,因为需要通过一次前向传播才能得到后续 EAGLE 所需要的特征。

为了一次验证多个序列,EAGLE 采用了 Tree Attention 来生成树状结构的草稿,这样可以在一个前向传播过程中生成多个 token。

8.3 Speculative Sampling、Lookahead、Medusa、EAGLE 对比

EAGLE 论文中用下图对比了 Speculative Sampling、Lookahead、Medusa、EAGLE 不同方法生成草稿的对比。

9 总结

推测解码(Speculative Decoding)是一种通过预生成多个候选 token 并并行验证以加速 LLM 推理的技术,旨在突破传统自回归解码中的内存带宽瓶颈。本文系统介绍了从早期草稿模型方法、Prompt Lookup 到 Jacobi Decoding、Lookahead、Medusa,再到当前速度领先的 EAGLE 等多种方案。尽管各方案实现方式不同,它们的共同目标是提升解码效率、降低推理延迟,同时保证生成文本的质量。

10 附录

10.1 什么是 Logits?

在大语言模型(LLM)中,logits 指的是模型输出层的原始得分向量,这些数值是未经过归一化处理的实数,表示模型对每个 token 作为下一个生成 token 的“置信度”或“倾向性”。具体来说:

# 假设词表:vocab = ["This", "apple", "is", "very", "delicious", "bad", "today"]# 模型输出的 logits(一个向量,对应词表中每个 token 的得分):logits = [1.8, 2.0, 2.5, 1.2, 0.5, 0.1, -0.7]
softmax(xi)=exijexj\text{softmax}(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}}

计算每个元素的指数:

exp(logits)=[e1.8, e2.0, e2.5, e1.2, e0.5, e0.1, e0.7][6.05, 7.39, 12.18, 3.32, 1.65, 1.11, 0.50]\exp(\text{logits}) = [e^{1.8},\ e^{2.0},\ e^{2.5},\ e^{1.2},\ e^{0.5},\ e^{0.1},\ e^{-0.7}]\approx [6.05,\ 7.39,\ 12.18,\ 3.32,\ 1.65,\ 1.11,\ 0.50]

求指数和:

i=17exi=6.05+7.39+12.18+3.32+1.65+1.11+0.50=32.20\sum_{i=1}^{7} e^{x_i} = 6.05 + 7.39 + 12.18 + 3.32 + 1.65 + 1.11 + 0.50 = 32.20

最终 softmax 概率为:

softmax(xi)=exijexj[6.0532.20, 7.3932.20, 12.1832.20, 3.3232.20, 1.6532.20, 1.1132.20, 0.5032.20][0.188, 0.229, 0.378, 0.103, 0.051, 0.034, 0.016]\text{softmax}(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}} \Rightarrow\left[\frac{6.05}{32.20},\ \frac{7.39}{32.20},\ \frac{12.18}{32.20},\ \frac{3.32}{32.20},\ \frac{1.65}{32.20},\ \frac{1.11}{32.20},\ \frac{0.50}{32.20}\right]\approx [0.188,\ 0.229,\ 0.378,\ 0.103,\ 0.051,\ 0.034,\ 0.016]
TokenLogitSoftmax概率
"This"1.80.188
"apple"2.00.229
"is"2.50.378 ✅(最高)
"very"1.20.103
"delicious"0.50.051
"bad"0.10.034
"today"-0.70.016

总的来说,logits 是 LLM 中预测下一个词的“未归一化概率得分”,是生成过程中的关键中间表示,决定了模型如何选择下一个输出 token。logits 的长度等于词表(vocabulary)的大小。

10.2 雅可比迭代法(Jacobi method)

雅可比迭代法(Jacobi method)是一种用于求解线性方程组的迭代算法。它通过迭代地更新方程组中每个未知数的近似值,逐步逼近真实的解。这种方法特别适用于大型、稀疏的线性方程组。

雅可比迭代法的核心思想是将复杂的线性方程组分解为简单的对角部分和非对角部分,然后利用当前迭代值计算下一次迭代值,实现并行计算。当矩阵满足对角占优条件时(即对角元素的绝对值大于同行其他元素绝对值之和),不论初始值如何选取,迭代过程都会收敛到唯一解,因为每次迭代都会使误差逐步减小。

10.2.1 公式

给定一个 n×nn \times n 的线性方程组:

Ax=bAx = b

其中:

A=[a11a12a1na21a22a2nan1an2ann],x=[x1x2xn],b=[b1b2bn]A = \begin{bmatrix}a_{11} & a_{12} & \cdots & a_{1n} \\a_{21} & a_{22} & \cdots & a_{2n} \\\vdots & \vdots & \ddots & \vdots \\a_{n1} & a_{n2} & \cdots & a_{nn}\end{bmatrix}, \quadx = \begin{bmatrix}x_1 \\x_2 \\\vdots \\x_n\end{bmatrix}, \quadb = \begin{bmatrix}b_1 \\b_2 \\\vdots \\b_n\end{bmatrix}

将矩阵 AA 分解为对角矩阵部分 DD 与其余部分 RR

A=D+R其中D=[a11000a22000ann],R=[0a12a1na210a2nan1an20]A = D + R \quad 其中 \quad D = \begin{bmatrix}a_{11} & 0 & \cdots & 0 \\0 & a_{22} & \cdots & 0 \\\vdots & \vdots & \ddots & \vdots \\0 & 0 & \cdots & a_{nn}\end{bmatrix}, \quadR = \begin{bmatrix}0 & a_{12} & \cdots & a_{1n} \\a_{21} & 0 & \cdots & a_{2n} \\\vdots & \vdots & \ddots & \vdots \\a_{n1} & a_{n2} & \cdots & 0\end{bmatrix}

此时线性方程可以变形为:

Dx=bRxDx = b - Rx

即将原方程拆成两部分,方便利用已知的 xx 去迭代求解。

迭代表达式:

x(k+1)=D1(bRx(k))x^{(k+1)} = D^{-1}(b - Rx^{(k)})

表示使用当前第 kk 次迭代的结果,去计算第 k+1k+1 次的近似解。

也可以写为单元素更新公式,可以直接看到第 ii 个未知量如何受其他分量影响:

xi(k+1)=1aii(bijiaijxj(k)),i=1,2,,nx_i^{(k+1)} = \frac{1}{a_{ii}} \left( b_i - \sum_{j \neq i} a_{ij} x_j^{(k)} \right), \quad i = 1, 2, \ldots, n

10.2.2 具体例子

以一个 3×33 \times 3 线性方程组为例:

{4x1+x2+x3=6x1+5x2+x3=7x1+x2+6x3=8\begin{cases}4x_1 + x_2 + x_3 = 6 \\x_1 + 5x_2 + x_3 = 7 \\x_1 + x_2 + 6x_3 = 8\end{cases}

用矩阵形式表示为:

A=[411151116],b=[678]A = \begin{bmatrix} 4 & 1 & 1 \\ 1 & 5 & 1 \\ 1 & 1 & 6 \end{bmatrix}, \quadb = \begin{bmatrix} 6 \\ 7 \\ 8 \end{bmatrix}

应用雅可比迭代法的单元素更新公式:

对于 x1x_1

x1(k+1)=1a11(b1j1a1jxj(k))=14(61x2(k)1x3(k))x_1^{(k+1)} = \frac{1}{a_{11}} \left( b_1 - \sum_{j \ne 1} a_{1j} x_j^{(k)} \right)= \frac{1}{4} \left( 6 - 1 \cdot x_2^{(k)} - 1 \cdot x_3^{(k)} \right)

对于 x2x_2

x2(k+1)=1a22(b2j2a2jxj(k))=15(71x1(k)1x3(k))x_2^{(k+1)} = \frac{1}{a_{22}} \left( b_2 - \sum_{j \ne 2} a_{2j} x_j^{(k)} \right)= \frac{1}{5} \left( 7 - 1 \cdot x_1^{(k)} - 1 \cdot x_3^{(k)} \right)

对于 x3x_3

x3(k+1)=1a33(b3j3a3jxj(k))=16(81x1(k)1x2(k))x_3^{(k+1)} = \frac{1}{a_{33}} \left( b_3 - \sum_{j \ne 3} a_{3j} x_j^{(k)} \right)= \frac{1}{6} \left( 8 - 1 \cdot x_1^{(k)} - 1 \cdot x_2^{(k)} \right)

假设初始值为:

x(0)=[000]x^{(0)} = \begin{bmatrix}0 \\0 \\0\end{bmatrix}

进行迭代计算:

第一次迭代:

x1(1)=14(600)=1.5x_1^{(1)} = \frac{1}{4}(6 - 0 - 0) = 1.5
x2(1)=15(700)=1.4x_2^{(1)} = \frac{1}{5}(7 - 0 - 0) = 1.4
x3(1)=16(800)=1.333x_3^{(1)} = \frac{1}{6}(8 - 0 - 0) = 1.333

所以:

x(1)=[1.51.41.333]x^{(1)} = \begin{bmatrix}1.5 \\1.4 \\1.333\end{bmatrix}

第二次迭代:

x1(2)=14(61.41.333)=3.26740.817x_1^{(2)} = \frac{1}{4}(6 - 1.4 - 1.333) = \frac{3.267}{4} \approx 0.817
x2(2)=15(71.51.333)=4.16750.833x_2^{(2)} = \frac{1}{5}(7 - 1.5 - 1.333) = \frac{4.167}{5} \approx 0.833
x3(2)=16(81.51.4)=5.160.85x_3^{(2)} = \frac{1}{6}(8 - 1.5 - 1.4) = \frac{5.1}{6} \approx 0.85

所以:

x(2)=[0.8170.8330.85]x^{(2)} = \begin{bmatrix}0.817 \\0.833 \\0.85\end{bmatrix}

继续迭代,最终会收敛到精确解:

x=[111]x = \begin{bmatrix}1 \\1 \\1\end{bmatrix}

可以验证:

这个例子展示了雅可比迭代法如何通过反复应用单元素更新公式,逐步逼近三维线性方程组的解。

10.2.3 更通俗的例子

假设你、小明、小红三人一起玩一个猜数字游戏,每人猜一个数字,但你们的数字必须满足以下规则:

你的数字

x=6x小明x小红4x_{\text{你}} = \frac{6 - x_{\text{小明}} - x_{\text{小红}}}{4}

小明的数字

x小明=7xx小红5x_{\text{小明}} = \frac{7 - x_{\text{你}} - x_{\text{小红}}}{5}

小红的数字

x小红=8xx小明6x_{\text{小红}} = \frac{8 - x_{\text{你}} - x_{\text{小明}}}{6}

也就是说,每个人的数字都需要参考另外两个人的数字。

游戏过程(详细迭代示例)

第 0 轮(初始猜测)

刚开始,你们都没有头绪,于是都从 0 开始猜:

玩家猜测数字
0
小明0
小红0

第 1 轮

每个人根据上轮其他人的猜测更新自己的数字:

你的新数字

x(1)=6004=1.5x_{\text{你}}^{(1)} = \frac{6 - 0 - 0}{4} = 1.5

小明的新数字

x小明(1)=7005=1.4x_{\text{小明}}^{(1)} = \frac{7 - 0 - 0}{5} = 1.4

小红的新数字

x小红(1)=8006=1.333x_{\text{小红}}^{(1)} = \frac{8 - 0 - 0}{6} = 1.333

第 2 轮

继续迭代,用第1轮的结果更新:

你的新数字

x(2)=61.41.3334=0.817x_{\text{你}}^{(2)} = \frac{6 - 1.4 - 1.333}{4} = 0.817
x小明(2)=71.51.3335=0.833x_{\text{小明}}^{(2)} = \frac{7 - 1.5 - 1.333}{5} = 0.833
x小红(2)=81.51.46=0.85x_{\text{小红}}^{(2)} = \frac{8 - 1.5 - 1.4}{6} = 0.85

第 3 轮

继续迭代,用第2轮的结果更新:

你的新数字

x(3)=60.8330.854=1.079x_{\text{你}}^{(3)} = \frac{6 - 0.833 - 0.85}{4} = 1.079

小明的新数字

x小明(3)=70.8170.855=1.067x_{\text{小明}}^{(3)} = \frac{7 - 0.817 - 0.85}{5} = 1.067

小红的新数字

x小红(3)=80.8170.8336=1.058x_{\text{小红}}^{(3)} = \frac{8 - 0.817 - 0.833}{6} = 1.058

第 4 轮

继续迭代,用第3轮的结果更新:

x(4)=61.0671.0584=0.969x_{\text{你}}^{(4)} = \frac{6 - 1.067 - 1.058}{4} = 0.969
x小明(4)=71.0791.0585=0.973x_{\text{小明}}^{(4)} = \frac{7 - 1.079 - 1.058}{5} = 0.973
x小红(4)=81.0791.0676=0.976x_{\text{小红}}^{(4)} = \frac{8 - 1.079 - 1.067}{6} = 0.976

持续进行迭代(直到收敛)

随着迭代次数增加,你们的猜测会越来越稳定,最终逐渐收敛到:

玩家最终猜测(趋近值)
1.0
小明1.0
小红1.0

此时,三个人的数字几乎不会再变动,游戏可以结束。

为什么游戏可以收敛到正确答案?

这种猜数字游戏不仅仅是个数学游戏,还能反映现实世界中很多场景,比如:

Jacobi 迭代法的核心,就是不断应用简单的局部更新规则(局部信息),最终达到全局平衡状态(方程组解)。

10.3 Jacobi Decoding 详解

假设目标输入:

x = ["Alan", "Turing"]

最终想生成的理想 token 序列是:

["who", "was", "a"]

Step 0:初始猜测

假设给定一个初始不太靠谱的初始猜测序列:

y(0)=["the","computer","engineer"]\mathbf{y}^{(0)} = ["the", "computer", "engineer"]

Step 1:根据 y(0)\mathbf{y}^{(0)} 并行计算 y(1)\mathbf{y}^{(1)}

我们根据 y(0)\mathbf{y}^{(0)} 的前缀依次预测每个 token:

因此:

y(1)=["who","is","pioneer"]\mathbf{y}^{(1)} = ["who", "is", "pioneer"]

Step 2:使用 y(1)\mathbf{y}^{(1)} 继续迭代

此轮中,所有预测都基于 y(1)\mathbf{y}^{(1)}

得到:

y(2)=["who","was","a"]\mathbf{y}^{(2)} = ["who", "was", "a"]

与目标完全一致,收敛!

流程总结:

步骤y₁y₂y₃说明
y⁰thecomputerengineer初始猜测,完全错误
whoispioneer第一轮,部分修正
whowasa第二轮,完全收敛

11 参考资料

欢迎关注

Fish AI Reader

Fish AI Reader

AI辅助创作,多种专业模板,深度分析,高质量内容生成。从观点提取到深度思考,FishAI为您提供全方位的创作支持。新版本引入自定义参数,让您的创作更加个性化和精准。

FishAI

FishAI

鱼阅,AI 时代的下一个智能信息助手,助你摆脱信息焦虑

联系邮箱 441953276@qq.com

相关标签

Speculative Decoding LLM 推测解码 Medusa Jacobi Decoding
相关文章