掘金 人工智能 05月02日 18:13
深度学习基础理论:混合精度训练以及gradient-checkpoint原理
index_new5.html
../../../zaker_core/zaker_tpl_static/wap/tpl_guoji1.html

 

本文深入探讨了深度学习模型训练中混合精度技术的应用。通过对比单精度、半精度和混合精度训练,揭示了混合精度在提升训练速度、降低显存占用方面的优势。文章详细介绍了混合精度训练的实现方法,包括FP32权重主拷贝、损失缩放等关键技术,并提供了使用Apex和PyTorch原生AMP进行混合精度训练的实践指导。此外,还讨论了混合精度训练中可能遇到的问题及解决方案,为读者提供了全面的技术参考。

💡 混合精度训练结合了单精度(FP32)和半精度(FP16/BF16)的优势。单精度提供高精度,半精度则降低显存占用和加速计算。混合精度通常使用FP32存储模型权重,FP16/BF16进行前向和反向传播计算。

⚙️ 混合精度训练通过FP32权重主拷贝和损失缩放等技术解决半精度可能引发的数值问题。FP32权重主拷贝保留了权重的精确值,损失缩放则避免了梯度下溢。在反向传播前将loss增打2k2^k2k倍,反向传播再去除这个常数即可。

🛠️ Apex和PyTorch原生AMP是实现混合精度训练的常用工具。Apex提供了amp模块,通过设置opt_level实现不同程度的混合精度。PyTorch的GradScaler则用于动态调整损失缩放,确保训练的稳定性和准确性。

⚠️ 混合精度训练中需关注GPU的FP16支持和CPU利用率。支持Tensor Core的GPU(如2080Ti、Titan、Tesla等)通常表现更好。CPU利用率过高可能导致GPU利用率下降,影响训练速度。使用Apex框架时,需要注意溢出问题,并根据需要调整loss_scale参数。

更加好的排版见:www.big-yellow-j.top/posts/2025/…

不同精度训练

单精度训练single-precision)指的是用32位浮点数(FP32)表示所有的参数、激活值和梯度半精度训练half-precision)指的是用16位浮点数(FP16 或 BF16)表示数据。(FP16 是 IEEE 标准,BF16 是一种更适合 AI 计算的变种)混合精度训练mixed-precision)指的是同时使用 FP16/BF16 和 FP32,利用二者的优点。通常,模型权重和梯度使用 FP32,而激活值和中间计算使用 FP16/BF16

Image From: www.exxactcorp.com/blog/hpc/wh…

不同精度之间对比:

指标单精度(FP32)半精度(FP16/BF16)混合精度
精度较低(FP16),中(BF16)中高
显存占用较低
训练速度较慢
稳定性最佳稳定性低(FP16)稳定
适用场景小规模任务性能优先,大规模模型性能与稳定的平衡

混合精度训练arxiv.org/pdf/1710.03…

为什么不只用单精度训练(速度快/显存占用少)1、直接使用半精度(FP16)容易引发数值问题,如溢出(overflow)下溢(underflow):这里是因为单精度有效尾数(约10位尾数)较单精度要小得多,那么就会有一个问题因此在训练过程中,如果激活函数的梯度非常小,可能会因精度不足而被舍弃为零,导致梯度下溢。此外,当数值超过半精度的表示范围时,也会发生溢出问题。这些限制会使训练难以正常进行,导致模型无法收敛或性能下降;2、舍入误差(Rounding Error) 舍入误差指的是当梯度过小,小于当前区间内的最小间隔时,该次梯度更新可能会失败,用一张图清晰地表示:

Image: zhuanlan.zhihu.com/p/79887894总的来说就是:如果只用半精度会导致精度损失严重,因此就会提出用混合精度进行训练

解决上面用单精度造成的问题,在混合精度训练中论文提到的解决办法:

模型权重会同时维护两个版本:1、FP32权重(Master Copy):以32位浮点数表示,用于存储和更新权重的精确值。2、FP16权重(Working Copy):以16位浮点数表示,用于前向传播和反向传播的计算,减少显存占用并加速运算

这里就会有一个问题,反向传播过程中要计算梯度,如果(梯度用FP16)梯度很小,不也还是会出现溢出问题,作者后续提到LOSS SCALING可以解决这种问题。如果梯度很大也会导致溢出问题,梯度计算使用FP16,但在权重更新之前,梯度会转换为 FP32 精度进行累积和存储,从而避免因溢出导致的权重更新错误。另外之所以要用FP32对权重进行保存这是因为,作者研究发现更新 FP16 权重会导致 80% 的相对准确度损失。we match FP32 training results when updating anFP32 master copy of weights after FP16 forward and backward passes, while updating FP16 weightsresults in 80% relative accuracy loss

另外一方面,如果拷贝权重,不也等同于把显存的占用拉大了?参考知乎上描述显存占用上主要是中间过程值

下图展示了 SSD 模型在训练过程中,激活函数梯度的分布情况,容易发现部分梯度值如果用FP16容易导致最后的梯度值变为0,这样就会导致上面提到的溢出问题,那么论文里面的做法就是:在反向传播前将loss增打2k2^k倍,这样就会保证不发生下溢出(乘一个常数,后面再去除这个常数不影响结果),如何反向传播再去除这个常数即可。

git clone https://github.com/NVIDIA/apexcd apexpython3 setup.py install

分别用Apex和torch原生的ampMNIST数据集上进行测试(模型:1层卷积+池化+2层全连接层)

# Apexfrom apex import amp...model, optimizer = amp.initialize(model, optimizer, opt_level="O1", loss_scale="dynamic")...with amp.scale_loss(loss, optimizer) as scaled_loss:            scaled_loss.backward()# Ampfrom torch.cuda.amp import autocast, GradScaler...scaler = GradScaler()...scaler.scale(loss).backward()scaler.step(optimizer)scaler.update()model = CVModel(args= ModelArgs).to(device)scaler = GradScaler()optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)for _ in range(20):    with autocast():        out = model(in_data)        loss = nn.CrossEntropyLoss()(out, labels)    scaler.scale(loss).backward()    scaler.step(optimizer)    scaler.update()    optimizer.zero_grad()

ApexAmp参数(nvidia.github.io/apex/amp.ht…

1、opt_level欧1而不是零1):

O0:纯FP32训练,可以作为accuracy的baseline;O1:混合精度训练(推荐使用),根据黑白名单自动决定使用FP16(GEMM, 卷积)还是FP32(Softmax)进行计算。O2:“几乎FP16”混合精度训练,不存在黑白名单,除了Batch norm,几乎都是用FP16计算。O3:纯FP16训练,很不稳定,但是可以作为speed的baseline;

2、loss_scale="dynamic"

损失值处理(LOSS SCALING)默认是动态(初始一个较大的值,检查到溢出就减小)

测试效果:

准确率变化上

在公开数据集(CIFAR10)上进行测试(模型为resnet50)测试使用的设备为4090

训练集上变化

RunSmoothedValueStepTime显存占用
scalar-CIFAR10/scalar-256-amp0.80260.93641116.99 min15508
scalar-CIFAR10/scalar-256-apex0.80930.93661116.51 min13166
scalar-CIFAR10/scalar-256-fp320.79460.94561122.27 min22818

测试集上变化

RunSmoothedValueStepTime显存占用
scalar-CIFAR10/scalar-256-amp0.73020.80311116.99 min15508
scalar-CIFAR10/scalar-256-apex0.73230.79561116.51 min13166
scalar-CIFAR10/scalar-256-fp320.72500.80921122.27 min22818

根据知乎:NicolasDreaming.O实验建议:

import torchif torch.cuda.is_available():    device = torch.device("cuda")    compute_capability = torch.cuda.get_device_capability(device)    print(f"Compute Capability: {compute_capability[0]}.{compute_capability[1]}")else:    print("CUDA is not available.")

结果7≥7说明支持

如果训练时候 CPU 大量被占用的话,会导致严重的减速。具体表现在:CPU被大量占用后,GPU-kernel的利用率下降明显。估计是因为混合精度加速有大量的cast操作需要CPU参与,如果CPU拖了后腿,则会导致GPU的利用率也下降。

因为在Apexamp默认使用的是dynamic可以改为1024或者2048

显存优化

gradient-checkpoint参考:www.big-yellow-j.top/posts/2025/…

参考

1、arxiv.org/pdf/1710.03…2、www.exxactcorp.com/blog/hpc/wh…3、zhuanlan.zhihu.com/p/798878944、zhuanlan.zhihu.com/p/842197775、nvidia.github.io/apex/amp.ht…

Fish AI Reader

Fish AI Reader

AI辅助创作,多种专业模板,深度分析,高质量内容生成。从观点提取到深度思考,FishAI为您提供全方位的创作支持。新版本引入自定义参数,让您的创作更加个性化和精准。

FishAI

FishAI

鱼阅,AI 时代的下一个智能信息助手,助你摆脱信息焦虑

联系邮箱 441953276@qq.com

相关标签

混合精度 深度学习 模型训练
相关文章