机器之心 06月23日 19:39
无损减少80%激活值内存,提升5倍训练序列长度,仅需两行代码
index_new5.html
../../../zaker_core/zaker_tpl_static/wap/tpl_guoji1.html

 

香港中文大学(深圳)和上海交通大学的研究团队提出StreamBP算法,旨在解决大语言模型(LLM)长序列训练中的内存瓶颈问题。该算法通过对链式法则的线性分解和分步计算,显著降低了训练过程中激活值的内存占用,达到梯度检查点(gradient checkpointing)的20%左右。实验结果表明,StreamBP在相同内存限制下,最大序列长度可达梯度检查点的2.8-5.5倍,并且在相同序列长度下,其速度与梯度检查点接近甚至更快,适用于多种LLM目标函数,代码已开源。

🧠 StreamBP通过线性分解和分步计算,有效降低了LLM训练中激活值的内存占用,特别是针对Transformer层和lmhead层。

💡 StreamBP的核心在于避免储存单层的完整激活值,而是将反向传播过程进行线性分解和序列化计算,从而降低内存消耗。

🚀 实验结果显示,StreamBP在相同内存限制下,最大序列长度是梯度检查点的2.8-5.5倍,在相同序列长度下,速度接近甚至更快。

⚙️ StreamBP适用于SFT、GRPO、PPO和DPO等多种LLM目标函数,并已开源,方便集成到现有训练代码中。

📈 StreamBP在不同目标函数下都显著提升了最大序列长度,允许使用更大的批处理大小以加速训练。


本文的第一作者罗琪竣、第二作者李梦琦为香港中文大学(深圳)计算机科学博士生,本文在上海交通大学赵磊老师、香港中文大学(深圳)李肖老师的指导下完成。


长序列训练对于模型的长序列推理等能力至关重要。随着序列长度增加,训练所需储存的激活值快速增加,占据训练的大部分内存。即便使用梯度检查点(gradient checkpointing)方法,激活值依然占据大量内存,限制训练所能使用的序列长度。


来自港中文(深圳)和上海交通大学的团队提出 StreamBP 算法。通过对链式法则进行线性分解和分步计算,StreamBP 将大语言模型训练所需的激活值内存(logits 和 layer activation)降低至梯度检查点(gradient checkpointing)的 20% 左右。



论文标题:StreamBP: Memory-Efficient Exact Backpropagation for Long Sequence Training of LLMs

论文https://arxiv.org/abs/2506.03077

代码https://github.com/Ledzy/StreamBP


在相同内存限制下,StreamBP 最大序列长度为梯度检查点的 2.8-5.5 倍。在相同序列长度下,StreamBP 的速度和梯度检查点接近甚至更快。StreamBP 适用于 SFT、GRPO、PPO 和 DPO 等常见 LLM 目标函数。代码已开源,可集成至现有训练代码。


激活值内存和梯度检查点
在反向传播(Backpropagation, BP)的过程中,计算模型梯度需要用到模型的中间输出(激活值)。举例来说,对于模型中的线性变换的梯度为,因而计算的梯度时需要储存相应的激活值。


对于模型中的任意函数变换  的梯度由以下链式法则计算:



其中 L 为目标函数,为 Jacobian 矩阵。为了计算以上 Jacobian-vector product,需要在模型 forward 时储存函数变换的中间值(激活值),其内存消耗与 batch size、序列长度以及中间值维度正相关。


为了减少激活值的内存消耗,梯度检查点(gradient checkpointing)方法在 forward 时只储存每一层网络的输入,而不储存该层的中间值。在 backward 至该层时,将重新 forward 此层输入来计算得到该层激活值。使用梯度检查点时储存的激活值包括:


所有层的输入,一般为激活值内存的 5%-15%。单层的完整激活值,占据超过 85% 的激活值内存。
StreamBP 的核心思想
不同于梯度检查点,StreamBP 避免储存单层的完整激活值,而将单层的 BP 过程进行线性分解,序列化计算并累加。注意到对于函数变换,链式法则存在以下线性分解:


StreamBP 基于以下观察:对于 LLM 中的大部分函数变换,如 Transformer 层、lmhead 层,可通过策略性地将输出分块,使得计算块 Jacobian-vector product 所需的激活值远小于计算完整的 Jacobian-vector product。基于该观察,StreamBP 依次计算上式中 D 个块的 Jacobian-vector product 并累加,得到准确的梯度。


为了计算块 Jacobian-vector product,需要分析输入和输出的相关性,每次 forward 块输入 得到块输出,建立对应子计算图。以简单的线性变换 为例,输出和输入在行维度上一一对应。StreamBP 按行分块,每次计算单行的 Jacobian-vector product 并累加。下图对比了标准 BP 和 StreamBP 在上述线性变换下的实现:



D 步累加得到的和即为和准确梯度。相比于标准 BP,StreamBP 仅需储存和,且总计算 FLOPs 相同。下表为 StreamBP 和标准 BP 的内存和时间对比:



LLM 训练中的 StreamBP
StreamBP 应用于 LLM 中的 Transformer 层和 lmhead 层,分别用于降低层激活值和 logits 的内存消耗。


与线性变换不同,由于 Transformer 层存在注意力机制,块输出并非仅由对应位置的块输入决定,而与该块及以前所有位置的输入都有关。StreamBP 利用只与块有关的性质,建立了如下计算图:



StreamBP 所需储存的激活值和注意力掩码(橙色)大幅低于梯度检查点(橙色 + 白色部分)。


对于 lmhead 层,当以 SFT 或 GRPO 为目标函数时,观察到不同位置的 logits 对于目标函数的影响相互独立。因此,StreamBP 从序列维度分块,每次计算单块损失函数的梯度,从而只需储存单块 logits 和 logits 梯度。


图:StreamBP for SFT


图:StreamBP for GRPO


对于 DPO,由于非线性 sigmoid 函数的存在,每个位置的 logits 对于目标函数的影响并不独立。StreamBP 利用 logits 梯度在序列维度的独立性,分块进行梯度计算。


图:StreamBP for DPO


实验结果
我们在单张 A800-80GB GPU 上测试了不同大小的模型,StreamBP 的最大 BP 序列长度为标准 BP 的 23-36 倍,梯度检查点的 2.5-5.5 倍。


图:不同序列长度下的 BP 峰值内存


在现有 Transformers 框架下,StreamBP 的实现可避免计算掩码部分的 pre-attention score(见论文 3.2.2 部分),在长序列训练下相较于梯度检查点实现了加速。



通过使用 StreamBP,不同目标函数下最大的序列长度得到了大幅提升。在同样的序列长度下,StreamBP 允许更大的批处理大小以加速训练。


表:Qwen 3-4B 单个样本 BP 时间,序列长度为 9000。


在 Deepspeed ZeRO 分布式训练模式下,Distributed StreamBP 比梯度检查点的最大可训练序列长度提升了5—5.6倍。



© THE END 

转载请联系本公众号获得授权

投稿或寻求报道:liyazhou@jiqizhixin.com


今天看啥地址:http://www.jintiankansha.me/t/DsO9NBIBv9

Fish AI Reader

Fish AI Reader

AI辅助创作,多种专业模板,深度分析,高质量内容生成。从观点提取到深度思考,FishAI为您提供全方位的创作支持。新版本引入自定义参数,让您的创作更加个性化和精准。

FishAI

FishAI

鱼阅,AI 时代的下一个智能信息助手,助你摆脱信息焦虑

联系邮箱 441953276@qq.com

相关标签

StreamBP 大语言模型 内存优化 长序列训练
相关文章