原创 苏剑林 2025-08-07 23:22 北京
ODE不香了?
©PaperWeekly 原创 · 作者 | 苏剑林
单位 | 科学空间
研究方向 | NLP、神经网络
众所周知,生成速度慢是扩散模型一直以来的痛点,而为了解决这个问题,大家可谓“八仙过海,各显神通”,提出了各式各样的解决方案,然而长久以来并没一项工作能够脱颖而出,成为标配。什么样的工作能够达到这个标准呢?在笔者看来,它至少满足几个条件:
1. 数学原理清晰,能够揭示出快速生成的本质所在;
2. 能够单目标从零训练,不需要对抗、蒸馏等额外手段;
3. 单步生成接近 SOTA,可以通过增加步数提升效果。
根据笔者的阅读经历,几乎没有一项工作能同时满足这三个标准。然而,就在前段时间,arXiv 出了一篇《Mean Flows for One-step Generative Modeling》[1](简称 “MeanFlow”),看上去非常有潜力。接下来,我们将以此为契机,讨论一下相关思路和进展。
现有思路
扩散模型的生成加速工作已经有非常多,本博客前面也简单介绍过一些。总的来说,加速思路大体上可以分为三类。
第一,将扩散模型转化为 SDE/ODE,然后研究更高效的求解器,代表作是 DPM-Solver [2] 及其一系列后续改进。
然而,这个思路通常只能将生成的 NFE(Number of Function Evaluations)降到 10 左右,再低就会明显降低生成质量。这是因为求解器的收敛速度通常都是正比于步长的若干次方,当 NFE 很小时步长就无法很小,所以收敛不够快以至于没法用。
第二,通过蒸馏将训练好的扩散模型转化为更少步数的生成器,由此衍生出来的工作和方案也非常多,我们此前介绍过其中的一种名为 SiD [3] 的方案。
蒸馏算是比较常规和通用的思路,但缺点也是共同的,即需要额外的训练成本,并非从零训练的方案。有些工作为了蒸馏到单步生成器,还加上了对抗训练等多重优化策略,整个方案往往过于复杂。
第三,基于一致性模型(Consistency Model,CM),包括我们在《生成扩散模型漫谈(二十八):分步理解一致性模型》[4] 简单介绍的 CM、它的连续版本 sCM [5] 以及 CTM [6] 等。
CM 是自成一派的思路,可以从零训练得到 NFE 很小的模型,也可以用于蒸馏,但 CM 的目标依赖于 EMA 或者 stop_gradient 运算,意味着它耦合了优化器动力学,这通常给人一种说不清道不明的感觉。
瞬时速度
到目前为止,生成 NFE 最小的扩散模型,基本上都是 ODE,因为确定性模型往往更容易分析和求解。本文同样只关注 ODE 式扩散,所用框架是《生成扩散模型漫谈:构建 ODE 的一般步骤(下)》介绍的 ReFlow,它跟 Flow Matching [7] 本质是相通的,但更加直观。
ODE 式扩散,是希望学习一个 ODE
来构建一个 的变换。具体来说,设 是某个容易采样的随机噪声, 则是目标分布的真实样本,我们希望能够通过上述 ODE,实现随机噪声到目标样本的变换,即随机采样一个 作为初值,求解上述 ODE 得到的 就是 的样本。
如果将 看成时间, 看成位移,那么 就是瞬时速度,所以 ODE 式扩散就是瞬时速度的建模。那怎么训练 呢?ReFlow 提出了一种非常直观的方法:首先构建 与 的任意插值方式,如最简单的线性插值 ,那么对 t 求导得
这是个极简单的 ODE,但不符合我们的要求,因为 是我们的目标,它不应该出现在 ODE 中。对此,ReFlow 提出一个非常符合直觉的想法——用 去逼近 :
这就是 ReFlow 的目标函数。值得指出的是:1)ReFlow 理论上允许 与 的任意插值方式;2)ReFlow 虽然直观,但理论上也是严格的,可以证明它的最优解确实是我们所求的 ODE。相关细节大家请参考《生成扩散模型漫谈:构建 ODE 的一般步骤(下)》以及原论文。
平均速度
然而,ODE 仅仅是一个纯数学形式,实际求解还是需要离散化,比如最简单的欧拉格式:
从 1 到 0 的 NFE 是 ,想要 NFE 小等价于 大。然而,ReFlow 的理论基础是精确的 ODE,即精确求解 ODE 时才能实现目标样本的生成,这意味着 越小越好,跟我们的期望相背。
尽管 ReFlow 声称使用直线插值可以让 ODE 的轨迹变得更直,从而允许更大的 ,但实际轨迹终究是弯曲的, 很难接近1,所以 ReFlow 很难实现一步生成。
归根结底,ODE 本来就是 的东西,我们非要将它用于 ,还要求它效果好,这本身就是“强模型所难”了。所以说,更换建模目标,而不是继续“为难”模型,才是实现更快生成的本质思路。为此,我们考虑对式(1)两端进行积分
如果我们可以建模
那么就有 ,即理论上可以精准地实现一步生成,而不必求诸于近似关系。如果说 是 t 时刻的瞬时速度,那么很显然 是 时间段内的平均速度。
也就是说,为了加速生成甚至一步生成,我们的建模目标应该是平均速度,而不是 ODE 的瞬时速度。
恒等变换
当然,从瞬时速度到平均速度的转变并不难想,真正难的地方是如何给它构建损失函数。ReFlow 只告诉我们如何给瞬时速度构建损失函数,对平均速度的训练我们是一无所知。
接下来很自然的想法是“化未知为已知”,即以平均速度 来为出发点来构建瞬时速度 ,然后代入 ReFlow 的目标函数,这需要我们去推导两者之间的恒等变换。从 的定义我们得到
两边对 求导,得到
这便是 跟 的第一个恒等关系。有第一自然就有第二,第二个恒等关系由平均速度的定义得到:
说白了,无限小区间内的平均速度,就等于瞬时速度。
第一目标
根据 以及恒等式(9),我们可以将恒等式(8)的 换成 或者 ,前者是隐式关系,我们后面再谈,我们先看后者,此时有:
代入 ReFlow,我们得到可以用来训练 的第一个目标函数:
这是一个非常理想的结果,它满足我们对生成模型目标函数的所有期望:
1. 单个显式的最小化目标;
2. 没有 EMA、stop_gradient 等运算;
3. 理论上有保证(ReFlow)。
这些特性意味着,不管我们用什么优化算法,只要我们能找到上式的最小值点,那么它就是我们想要的平均速度模型,即理论上能够实现一步生成的生成模型。
换句话说,它具备了扩散模型的训练简单和理论保证,又能像 GAN 那样一步生成,还不用求神拜佛保佑模型别“想不开”而训崩。
JVP 运算
不过,对于部分读者来说,目标函数(11)的实现还是有点困难的,因为它涉及到普通用户比较少见的“雅可比向量积(Jacobian-Vector Product,JVP)”。具体来说,我们可以将目标函数内方括号部分写成:
即 的雅可比矩阵与给定向量 的乘法,结果是一个跟 大小一致的向量,这种运算就叫做 JVP,在 Jax、Torch 里边都有现成实现,比如 Jax 的参考代码是:
u = lambda xt, r, t: diffusion_model(weights, [xt, r, t])
urt, durt = jax.jvp(u, (xt, r, t), (u(xt, t, t), r * 0, t * 0 + 1))
其中 urt 就是 ,而 durt 就是对应的 JVP 结果,Torch 的用法也类似。了解 JVP 运算后,目标函数(11)的实现就基本上没有难度了。
第二目标
如果要说目标函数(11)的缺点,在笔者看来只有一个,那就是计算量相对偏大。这是因为它要进行两次不同的前向传播 和 ,然后 JVP 求了一次梯度,用基于梯度下降优化时还要再求一次梯度,所以它本质上要求二阶梯度,跟以往的 WGAN-GP 类似。
为了降低计算量,我们可以考虑给 JVP 部分加上 stop_gradient 运算():
这样就避免了对 JVP 再次求梯度(但依然需要两次前向传播)。实测结果显示,相比第一目标(11),上述目标在梯度优化器下训练速度能够快将近一倍,并且效果目测无损。
注意,这里的 stop_gradient 单纯是出于减少计算量的目的,实际优化方向依然是损失函数值越小越好,这跟 CM 系列模型尤其是 sCM 是不一样的,它们的损失函数只是具有等效梯度的等效损失,并不一定是越小越好,它们的 stop_gradient 往往是必须的,一旦去掉几乎可以肯定会训练崩溃。
第三目标
前面我们提到,处理恒等式(8)中的 的另一个方案是将其换成 ,这将导致
如果要从中解出 ,结果将是
这涉及到了非常庞大的矩阵求逆,因此并不现实。MeanFlow 给出了一个折中方案:既然 的回归目标是 ,那干脆把 换成 好了,于是目标函数变成
然而,此时的 既是回归目标,又出现在模型 的定义中,难免会有一种“标签泄漏”的感觉。为了避免这个问题,MeanFlow 采取的办法同样是给 JVP 部分加上 stop_gradient:
这就是 MeanFlow 最终所用的损失函数,这里我们称之为“第三目标”。相比第二目标(13),它少了一次前向传播 ,所以训练速度会更快一些。但此时“标签泄漏”的引入和 stop_gradient 的对策,使得第三目标的训练跟梯度优化器是耦合的,这就跟 CM 一样,多了一些说不清道不明的神秘感。
论文实验结果表明,加上 的目标(17)是能训出合理结果的,那如果去掉它呢?
笔者向作者请教过,他表明去掉 后,训练依然能收敛,能多步生成,但没有一步生成能力了。其实这也不难理解,因为 时不管有没有 ,目标函数都退化为 ReFlow:
也就是说 MeanFlow 总有 ReFlow 在背后“兜底”,因此怎样也不至于太差。而去掉 后,“标签泄漏”的负面影响加剧,因此就不如加上它了。
证明一下
我们能否像 ReFlow 一样,从理论上证明第三目标(17)的最优解确实是我们期望的平均速度模型呢?让我们尝试一下。首先我们回顾证明 ReFlow 的两个关键引理:
1、,即最小化 与 的平方误差,最优解是 的均值;
2、按照分布轨迹 将 变到 的 ODE 形式解是 。
其中,引理 1 的证明比较简单,直接对 求梯度得 ,令它等于零即可。
引理 2 的证明细节则需要看《生成扩散模型漫谈:构建ODE的一般步骤(下)》,其中 是需要先利用 消去 ,得到一个 的函数,然后对分布 求期望,结果是关于 的函数。
利用引理 1,我们可以证明 ReFlow 的目标函数(3)的理论最优解就是 ,结合引理 2 就得到 是我们所求的 ODE。
第三目标(17)的证明类似,由于里边有 ,对 求梯度并让它等于零的结果是:
所以在适当的边界条件下就有 ,即我们期望的平均速度模型。
这个过程的关键是 的引入避免了对 JVP 部分求梯度,从而简化了梯度表达式并得到了正确的结果。
如果去掉 的话,上式右端就要多乘一项 JVP 部分对 的雅可比矩阵,结果就是最后无法将 这一项分离出来,而引入 的数学意义便是为了解决此问题。
当然,笔者还是那句话, 的引入也使得整个模型的训练耦合了梯度优化器,多了一丝不清晰的感觉。此时梯度等于零的点,顶多算是一个驻点而非(局部)最小值点,所以稳定性也不明朗,这其实也是所有耦合 的模型的共性。
相关工作
非常有趣的是,我们之前介绍过的两篇加速生成的文章《生成扩散模型漫谈:中值定理加速 ODE 采样》和《生成扩散模型漫谈:将步长作为条件输入》[8],也都是以平均速度为核心的,并且思想上可以说是一脉相承的。尽管作者之间未必相互有联系,但他们的工作内容上确实给人一种承前启后的连贯感。
在中值定理篇,作者已经意识到了平均速度
的重要性,但他的做法是类比一维函数的积分中值定理,试图寻找 使得 等于平均速度。这本质上还是寻找高阶 Solver 的思想,但不再是 Training-Free,而是需要少量的蒸馏步骤,对 Solver 来说算是一个小突破。
而步长输入篇所提的 Shortcut 模型,则几乎已经触碰到了 MeanFlow,因为步长作为额外输入,跟 MeanFlow 的双时间参数 r,t 实质是等价的,不同的是它是直接以平均速度的性质作为额外的正则项来训练模型。用本文的记号,平均速度应该满足的性质是
其中 。所以 Shortcut 干脆以它来构建正则项
跟 ReFlow 的目标(18)混合训练,实际训练中 , 的引入在笔者看来主要也是为了节省计算量。Shortcut 模型其实比 MeanFlow 更直观,但由于没有恒等变换和 ReFlow 带来的严格理论支撑,使得它看上去更多是一个过渡期的经验产物。
一致模型
最后我们再来讨论一下一致性模型。由于 CM、sCM 珠玉在前,MeanFlow 的成功实际上也借鉴了它们的经验,尤其是给 JVP 加 的操作,这在原论文中也有提到。
当然,MeanFlow 作者之一何恺明老师本身也是操控梯度的大师(比如 SimSiam),所以 MeanFlow 的出现看起来是非常水到渠成的。
离散的 CM 我们在《生成扩散模型漫谈:分步理解一致性模型》[4] 仔细分析过,如果将其中 CM 的 EMA 算符换成 stop_gradient,求梯度并取 的极限,那么就得到了《Simplifying, Stabilizing and Scaling Continuous-Time Consistency Models》[5] 中的 sCM 的目标函数:
如果将 换成 ,然后记 ,那么它的梯度跟 时的 MeanFlow 第三目标(17)是等价的:
所以,从这个角度看,sCM 是 MeanFlow 在 r=0 时的一个特例。正如前面所说,引入另外的时间参数 r 可以 让 ReFlow 给 MeanFlow “兜底”(r=t 时),从而更好地避免训崩,这是它的优点之一。
当然,从 sCM 出发其实也可以引入双时间参数,得到跟第三目标完全相同的结果,但从个人的审美来看,CM、sCM 的物理意义终究不如 MeanFlow 平均速度的诠释直观。
此外,平均速度和 ReFlow 结合的出发点,还可以得到另外的第一目标(11)和第二目标(13),这对于像笔者这样的 stop_gradient 洁癖患者来说是非常舒适和漂亮的结果。
在笔者看来,从计算成本出发,我们是可以考虑给损失函数加上 stop_gradient,但推导的第一性原理和基本结果不应该跟 stop_gradient 耦合,否则意味着它跟优化器和动力学是强耦合的,这并不是一个本质结果应有的表现。
文章小结
本文以最近出来的 MeanFlow 为中心,讨论了“平均速度”视角下的扩散模型加速生成思路。
参考文献
[1] https://papers.cool/arxiv/2505.13447
[2] https://papers.cool/arxiv/2206.00927
[3] https://kexue.fm/archives/10085
[4] https://kexue.fm/archives/10633
[5] https://papers.cool/arxiv/2410.11081
[6] https://papers.cool/arxiv/2310.02279
[7] https://papers.cool/arxiv/2210.02747
[8] https://kexue.fm/archives/10617
更多阅读
#投 稿 通 道#
让你的文字被更多人看到
如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。
总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。
PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析、科研心得或竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。
📝 稿件基本要求:
• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注
• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题
• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算
📬 投稿通道:
• 投稿邮箱:hr@paperweekly.site
• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者
• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿
△长按添加PaperWeekly小编
🔍
现在,在「知乎」也能找到我们了
进入知乎首页搜索「PaperWeekly」
点击「关注」订阅我们的专栏吧
·