原创 苏剑林 2025-07-03 22:07 北京
如何一步步走向反哺Transformer的“胜利曲线”?
©PaperWeekly 原创 · 作者 | 苏剑林
单位 | 科学空间
研究方向 | NLP、神经网络
在中文圈,笔者应该算是比较早关注线性 Attention 的了,在 2020 年写首篇相关文章线性Attention的探索:Attention必须有个Softmax吗?时,大家主要讨论的还是 BERT 相关的 Softmax Attention。
事后来看,在 BERT 时代考虑线性 Attention 并不是太明智,因为当时训练长度比较短,且模型主要还是 Encoder,用线性 Attention 来做基本没有优势。对此,笔者也曾撰文线性Transformer应该不是你要等的那个模型表达这一观点。
直到 ChatGPT 的出世,倒逼大家都去做 Decoder-only 的生成式模型,这跟线性 Attention 的 RNN 形式高度契合。同时,追求更长的训练长度也使得 Softmax Attention 的二次复杂度瓶颈愈发明显。
在这样的新背景下,线性 Attention 越来越体现出竞争力,甚至出现了“反哺”Softmax Attention 的迹象。
平方复杂度
首先引入一些记号:
一个 Attention 模型,本质上是一个 的映射。本文主要关心 Causal 场景,这意味着 至多跟 相关。
原则上, 的 d 与 的 d 可以不一致,比如 GAU 和 MLA 便是如此,但将它们简化成同一个并不改变问题本质。
标准的 Softmax Attention,通常是指 Attention is All You Need 所提的 Attention 机制:
这里省略了缩放因子 ,因为它总可以吸收到 里边, 是对第二个维度进行指数归一化,而 是一个下三角阵,称为掩码矩阵,定义为:
其中分母的作用主要是保持数值稳定性,另外就是如果我们给
其中
Softmax Attention 的标准实现需要把
最初的模样
线性 Attention 最早的思路主要是模仿和近似 Softmax Attention,其中最简单的方案是直接去掉
简单起见,我们约定矩阵乘法的优先级高于 Hadamard 积,这样可以省掉一组括号。为什么这个形式是“线性”Attention 的呢?
为了快速理解这一点,我们不妨先考虑去掉
至于 Causal 版(6),我们可以从分量形式理解,写出:
如果我们记括号部分为
由此可见,Causal 形式的 Attention 可以写成一个以
注意这里出现了“线性 RNN”,它是更广义的概念,线性 Attention 属于线性 RNN 的一种,线性 RNN 也单独发展过一段时间,比如之前介绍过的 LRU、SSM 等,但最近比较有竞争力的线性架构都具有线性 Attention 的形式。
早年的线性 Attention 还有一些非常明显的模仿 Softmax Attention 的特点,比如会给式(6)加入分母来归一化,而为了归一化,那么
然而,后来的研究如《The Devil in Linear Transformer》[3] 发现,在序列长度维度归一化并不能完全避免数值不稳定性,倒不如直接事后归一化,如:
而既然不用归一化,那么给
笔者的观点是,加激活函数是大家的自由,不排除加某个激活函数能够调出更好的效果,但加激活函数并不改变线性 Attention 的形式,所以不影响我们的描述,另外就是现有的结果表明,其实不加已经足够好。
花式遗忘门
从式(8)我们可以看出,目前的线性 Attention 本质上就是个
为了缓解这个问题,RetNet [4] 给线性 Attention 引入了遗忘效应:
其中衰减因子
注意,衰减因子在 RetNet 前也有,不过它们多以线性RNN的形式出现,如上一节提到的LRU、SSM 等,RetNet 应该是首次将它跟线性 Attention 结合起来。
加入衰减因子后,模型会倾向于遗忘掉更为久远的历史信息,从而至少保证最近 token 的分辨率,说白了就是跟语言模型特性相符的“就近原则(Recency Bias)”的体现,从而往往能工作得更好。
此外,一个值得关注的细节是 RetNet 还给
尽管给 RNN 加位置编码的操作看上去似乎有点违和,但有些实验比如最近的 TransXSSM [6] 表明,给线性 Attention 加 RoPE 也有一定的正面作用。当然,这可能取决于具体的模型变体和实验设置。
式(10)的一个简单推广是将
后来,DFW [7]、Mamba [8]、Mamba2 [9] 等工作,将它推广成跟输入相关,形成了“data-dependent decay”相关的一系列工作,这跟以往 GRU、LSTM 等非线性 RNN 的“遗忘门(forget gate)”其实已经非常相似了,只不过为了保持模型的线性性,去掉了遗忘门对 State(如
为什么我们偏爱线性 RNN 呢?因为线性 RNN 基本都能找到某种方式来并行训练,这使得它相比 Softmax Attention 更具竞争力——在训练效率和推理效率上都不逊色。
其中,并行化的“通解”是转化为 Prefix Sum [10] 问题然后 Associative Scan,大体思路我们在Google新作试图“复活”RNN:RNN能否再次辉煌?的“并行化”一节也简单介绍过。
然而,“通解”并不是 GPU 高效的,GPU 最高效的是矩阵乘法,所以找到大量使用矩阵乘法的并行算法是最理想的,甚至都不用并行,只要找到充分使用矩阵乘法的 Chunk by Chunk 递归格式,都能明显提高训练效率。
这反过来对模型提出了要求,如只有外积形式的遗忘门才能实现这个目的,典型反例就是 Mamba,它是非外积的遗忘门,无法充分发挥 GPU 的性能,所以才有了后续 Mamba2 和 GLA [11] 等变化。
测试时训练
至此,线性 Attention 从最初的简单模仿 Softmax Attention,到引入静态衰减因子乃至“data-dependent decay”,已经形成了自身的特色并在不少任务上发挥价值。
然而,这些进展多数是靠人工凭经验设计出来的,我们不禁要问:有没有更上层的原则来指导线性 Attention 甚至是一般的序列模型(Token-Mixer)的设计?
对于这个问题,TTT(Test Time Training)[12] 给出了自己的答案,它将序列模型的构建视为一个“在线学习(Online Learning)”问题,并提出用优化器来构建(不一定是线性的)RNN 的做法。
具体来说,它将
这跟 RNN 有什么关系呢?很简单,优化器如 SGD、Adam 等,它们本质上就是一个关于模型参数的 RNN!
其实这个观点并不新鲜,早在 2017 年 Meta Learning 盛行那会就已经有研究人员提出并利用了这点,只不过当时的想法是尝试用 RNN(LSTM)去模拟一个更好的优化器,详情可以参考《Optimization as a Model for Few-Shot Learning》[13]。
正所谓“风水轮流转”,时隔多年 TTT 反过来提出通过优化器来构建 RNN。它的流程是这样的:首先,当前模型参数为
所以,TTT 所实现的 RNN 可以统一地写成:
其中
这个形式可以覆盖非常多的 RNN 模型,比如式(8)和(10)都是它的一个特例:
TTT 原文则致力于探索 mini-batch 下的非线性 RNN,后来的 Titans [14] 则给 TTT 的 SGD 加上了动量,再后面《Test-Time Training Done Right》[15] 则探索了 large-batch 的 TTT 用法,还探索了“TTT + Muon”的组合。
注意,TTT 只是利用优化器来构建 RNN,RNN 以外的参数如
一个更值得思考的问题是:为什么 TTT 可以成为构建 RNN 的“指导原则”呢?
RNN 的核心目标,是将历史数据有效地压缩到一个固定大小的 State 中,而模型参数正好是固定大小的,训练模型某种程度上就相当于把训练数据压缩到模型权重中,TTT 正是利用了它跟 RNN 目标的高度契合性。
说直白一点,如果将 RNN 视为一个压缩任务,TTT 将模型
这样一来,我们就不用花心思构建递归格式了,转而构建模型
除此之外,TTT 用 Online Learning 构建 RNN,意味着所得 RNN 必然非常契合 ICL(In Context Learning)任务,这也是 TTT 作为“指导原则”的优势。
此前《Why Can GPT Learn In-Context? Language Models Implicitly Perform Gradient Descent as Meta-Optimizers》[16] 甚至反过来,将 Softmax Attention 去掉 Softmax 成线性 Attention 来解释它的 ICL 能力,用现在的视角看它就是构造了对应的 TTT 出来。
除旧而迎新
例如,最早的线性 Attention 对应的损失函数是
相比之下,RetNet 往损失函数加入了 L2 正则项,避免了这种风险,从优化角度看也缓解了过拟合的风险,从而得到一个更好的 RNN。
然而,用内积作为损失函数虽然简洁且有一定道理,但它不是直接鼓励
这便是 DeltaNet,这个名字出自《Parallelizing Linear Transformers with the Delta Rule over Sequence Length》[17],更早则是由《Linear Transformers Are Secretly Fast Weight Programmers》[18] 提出。
留意到
如果有需要,我们再把
直观来想,“先减后加”就是先移除模型对
Delta Rule并不新鲜,它又称为 Least Mean Square [20]、Widrow-Hoff Algorithm 等,已经是上个世纪 60 年代的产物了。事实上,这个领域完全新的东西很少,很多改动都可以追溯到某个“上古时期”的工作,目前的努力主要集中在挖掘其中能 Scalable 的部分。
另外需要指出的是,按照时间的顺序,是 DeltaNet 在前,TTT 在后,从 Online Learning 角度理解 RNN,其实在 TTT 之前已经零星地体现在一些工作中,但 TTT 系统地提出了这个“指导原则”,并且将它用于构建新 RNN 模型,所以我们把 TTT 放在前面,使得整个介绍更加流畅自然一些。
有些读者可能疑问:DeltaNet 还算线性 RNN 吗?
答案是肯定的。我们所说的线性 RNN,是指递归公式对 State 变量的依赖关系是线性的,但对输入或
求逆与推广
前面我们说了,线性 RNN 最理想的(即 GPU 高效的)并行算法是充分使用矩阵乘法的形式。为了完成这一目标,我们先将 DeltaNet 写成:
记
最后的等式写成矩阵形式是
这里出现了
进一步地,利用
DeltaNet 之后,Gated DeltaNet(GDN)[21] 进一步地将遗忘门引入到 DeltaNet 之中,这倒是可以预料的变化。Gated DeltaNet 的原始引入方式是:
但个人认为,这个提法其实显式打破了 Delta Rule,更好的提法应该是像 Comba [22] 一样,只乘到第一个
它相当于将损失函数取
即
从理论上来说,Gated DeltaNet 也可以写成 DeltaNet 的形式,因为只需要定义
然后结合
不过,这个结果只有在某些情况下具有理论推导的价值(比如推导下一节的 Attention 矩阵),因为实际计算中,不管怎么参数化,对于足够大的 t,
DeltaNet 之后还有另一个推广 DeltaProduct [23],它是将
不过,就笔者的审美而言,与其像DeltaProduct那样扩展常数倍,还不如像时空之章:将Attention视为平方复杂度的RNN一样尝试平方复杂度的 RNN,看有没有机会超越 Softmax Attention。
反哺进行时
说到超越 Softmax Attention,开头提到,如今的线性 Attention 不仅能与 Softmax Attention 一较高低,甚至开始“反哺”它。这看似不可思议,但细思之下并不难理解。
某种意义上,这些年 Softmax Attention 一直在退步,从 MHA、GQA 到 MQA 都是为了压缩 KV Cache 而做减法。而线性 Attention 没有 KV Cache 问题,所以一直往更好的方向前进。
为了更好看出这一点,我们不妨将前面提到的 Attention 机制都以矩阵形式写出来:
其中:
以及
首先我们需要一种方法把 Softmax Attention 转化为线性 Attention,这个并不难,早在 Transformer升级之路:作为无限维的线性Attention [24] 我们就总结了三种将 Softmax Attention 转化为无限维线性 Attention 的方案。
总之,就是存在一个映射
那接下来的事情就简单了,我们只需将上述表格中的线性 Attention 的
如果
一个更有意思的结果是《Understanding Transformer from the Perspective of Associative Memory》[27] 所提的 DeltaFormer,顾名思义它是 Softmax Attention 的 DeltaNet 版本。将 DeltaNet 的
如果要归一化,我们将
所以 DeltaFormer 相当于先用
此外,DeltaFormer 的这个特点还意味着它跟 MQA 特别搭配,因为
不过,在笔者看来,这种固定系数的叠加可能是“没有免费午餐”,比如笔者的实验结果显示,DeltaFormer 的语言模型损失并无太大变化,这意味着如果某些任务的损失明显降低,必然有另一些任务的损失上升了。
硬核编码术
还有一个值得关注的反哺工作是 PaTH Attention,出自《PaTH Attention: Position Encoding via Accumulating Householder Transformations》[28],它从位置编码的角度将 DeltaNet 反哺到 Softmax Attention。
我们在Transformer升级之路:旋转位置编码的完备性分析指出,对于任何正交矩阵
除了旋转矩阵,还有哪些容易构建的正交矩阵呢?
PaTH 用的是 Householder 矩阵 [29]:设
容易看出,这跟 DeltaNet 中
将
其中
注意求逆的是下三角阵,三角阵有一个重要特性,逆矩阵的对角线元素等于原矩阵对角线元素的倒数,如果是分块三角阵则对角块也满足这个特性,于是我们可以写出:
接下来的变换,写成分量形式可能好理解一些:
这里有几个关键点:比较巧妙的是第4个等号,它利用了
第 6 个等号,当我们分别处理 p,s 两部分求和时,结果是
至此,我们可以把整个(Softmax 之前的)注意力矩阵写出来:
有没有被震惊到?这还没完。直接求逆复杂度是
从位置编码的角度看,PaTH 是 CoPE(Contextual Position Encoding)[31] 的一种,它的位置并不是编号
类似地,FoX 也可以看成是 Contextual 版的 ALIBI。上下文相关的位置信息是当前线性 Attention 的主要特征,也可能是反哺 Softmax Attention 的主要方向。
化简乐无穷
我们不妨再深入点探讨一下 PaTH,这不仅有助于我们了解 PaTH,也能帮助我们更熟悉 DeltaNet,两者本身就是高度相关的。这一节我们从 PaTH 的两个特例入手,它可以帮助我们更好地理解 PaTH 与 DeltaNet 的关联。
第一个特例是
有没有觉得有点熟悉?这刚好就是 DeltaNet 的 Attention 矩阵!从这个特例看来,PaTH 和 DeltaFormer 的区别就在于,DeltaFormer 基于核技巧,给 DeltaNet 的
第二个特例是重新引入
那么
写成矩阵形式就是:
是不是又觉得有点熟悉?其实第二部分就是
也就是用 DeltaNet 给
当然我们也可以考虑放弃前面的推导,即便
剑走偏锋法
最后,我们再看最近的一个同样值得关注的线性 Attention 模型——MesaNet(还有一个大同小异的同期工作 Atlas [33])。
TTT 的 Online Learning 视角告诉我们,DeltaNet 其实就是在用 SGD 优化目标函数
MesaNet 就是利用这个解析解来构建序列模型的,其想法起源于《Uncovering mesa-optimization algorithms in Transformers》[34],高效训练则是由《MesaNet: Sequence Modeling by Locally Optimal Test-Time Training》[35] 实现。
MesaNet 在上述公式基础上给
很明显,
从信号处理的角度看,MesaNet 与 DeltaNet 是 Recursive Least Square [36] 和 Least Mean Square [37] 的区别。
看上去都是优点,为啥笔者会将它归入“剑走偏锋”呢?在笔者看来,MesaNet“成也解析解,败也解析解”,解析解使得它通常优于 DeltaNet,但也给人一种“到此为止”的感觉,因为只要稍变一下就几乎没有机会求得解析解了。
纵观整个数学史,所有依赖于解析解的分支在今天几乎已经都没落了,因为解析解实在太稀罕、太没有代表性了。
从实现上来看,MesaNet 需要求逆的矩阵
如何尽可能低成本地并行计算全体
再就是从理论能力上看,MesaNet 也并非严格优于 DeltaNet。这是因为 MesaNet 的
直观理解就是,MesaNet 会尽力记住全体
总的来说,MesaNet 是一个让人赏心悦目的模型,但解析解也增加了它的复杂性和限制了它的灵活性,留下了不少亟待探索的空间。如果读者想要了解更多基于线性回归来构建序列模型的内容,还可以阅读 TTR [38],它对各种线性回归目标下的序列模型做了详细讨论。
方兴未艾路
本文简要梳理了线性 Attention 的发展脉络,并介绍了部分模型的数学原理。线性 Attention 从模仿 Softmax Attention 起步,逐渐发展出自身特色,如今已成为极具竞争力的序列建模方案,甚至反过来为 Softmax Attention 的发展提供了新思路,这一过程本身充满了趣味性和启发性。
参考文献
[1] https://papers.cool/arxiv/2205.14135
[2] https://papers.cool/arxiv/2103.02143
[3] https://papers.cool/arxiv/2210.10340
[4] https://papers.cool/arxiv/2307.08621
[5] https://papers.cool/arxiv/2501.08313
[6] https://papers.cool/arxiv/2506.09507
[7] https://papers.cool/arxiv/2210.04243
[8] https://papers.cool/arxiv/2312.00752
[9] https://papers.cool/arxiv/2405.21060
[10] https://en.wikipedia.org/wiki/Prefix_sum
[11] https://papers.cool/arxiv/2312.06635
[12] https://papers.cool/arxiv/2407.04620
[13] https://openreview.net/forum?id=rJY0-Kcll
[14] https://papers.cool/arxiv/2501.00663
[15] https://papers.cool/arxiv/2505.23884
[16] https://papers.cool/arxiv/2212.10559
[17] https://papers.cool/arxiv/2406.06484
[18] https://papers.cool/arxiv/2102.11174
[19] https://en.wikipedia.org/wiki/Delta_rule
[20] https://en.wikipedia.org/wiki/Least_mean_squares_filter
[21] https://papers.cool/arxiv/2412.06464
[22] https://papers.cool/arxiv/2506.02475
[23] https://papers.cool/arxiv/2502.10297
[24] https://kexue.fm/archives/8601
[25] https://papers.cool/arxiv/2108.12409
[26] https://papers.cool/arxiv/2503.02130
[27] https://papers.cool/arxiv/2505.19488
[28] https://papers.cool/arxiv/2505.16381
[29] https://en.wikipedia.org/wiki/Householder_transformation
[30] https://kexue.fm/archives/8453
[31] https://papers.cool/arxiv/2405.18719
[32] https://papers.ssrn.com/sol3/papers.cfm?abstract_id=5240330
[33] https://papers.cool/arxiv/2505.23735
[34] https://papers.cool/arxiv/2309.05858
[35] https://papers.cool/arxiv/2506.05233
[36] https://en.wikipedia.org/wiki/Recursive_least_squares_filter
[37] https://en.wikipedia.org/wiki/Least_mean_squares_filter
[38] https://papers.cool/arxiv/2501.12352
更多阅读
#投 稿 通 道#
让你的文字被更多人看到
如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。
总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。
PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析、科研心得或竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。
📝 稿件基本要求:
• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注
• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题
• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算
📬 投稿通道:
• 投稿邮箱:hr@paperweekly.site
• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者
• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿
△长按添加PaperWeekly小编
🔍
现在,在「知乎」也能找到我们了
进入知乎首页搜索「PaperWeekly」
点击「关注」订阅我们的专栏吧
·
·
·