在 PyTorch 训练中,torch.cuda.empty_cache()
的正确使用需要结合具体场景。以下是 5 种典型使用场景和最佳实践建议:
1. 训练循环中显存碎片化严重时适用场景• 当出现 CUDA out of memory
但计算显存需求理论上足够时
• 使用动态计算图或频繁创建/释放临时张量
推荐位置
for epoch in range(epochs): for batch in dataloader: # 前向/反向计算... optimizer.step() if batch_idx % 100 == 0: # 每100个batch清理一次 torch.cuda.empty_cache() # 清理未使用的缓存
注意事项• 过于频繁调用(如每个batch)会降低性能(约5-10%速度损失)
• 建议配合 torch.cuda.memory_summary()
诊断碎片化程度
2. 大模型分阶段加载时适用场景• 使用梯度检查点 (Gradient Checkpointing)
• 模型太大需要分段加载
示例代码
def forward_segment(segment, x): torch.cuda.empty_cache() # 加载新段前清理 segment = segment.to(device) return segment(x)for segment in large_model.segments: output = checkpoint(forward_segment, segment, inputs)
3. 多任务交替执行时适用场景• 同一个脚本交替执行训练和评估
• 不同模型共享GPU资源
代码示例
# 训练阶段train(model_A)torch.cuda.empty_cache() # 训练后立即清理# 评估阶段evaluate(model_B) # 确保model_B能获得足够显存
4. 数据预处理与训练混合时适用场景• 使用GPU加速数据增强
• 动态生成训练数据
推荐写法
for epoch in epochs: # GPU数据增强 augmented_batch = gpu_augment(batch) # 训练主模型 train_step(model, augmented_batch) # 清理增强操作的中间缓存 del augmented_batch torch.cuda.empty_cache()
5. 异常恢复后适用场景• 捕获 CUDA OOM
异常后尝试恢复
• 测试最大可用batch size时
代码实现
try: large_batch = next(oversized_loader) output = model(large_batch)except RuntimeError as e: if "CUDA out of memory" in str(e): torch.cuda.empty_cache() # 尝试释放残留显存 reduced_batch = large_batch[:half_size] # 重试...
最佳实践总结
场景 | 调用频率 | 是否必需 | 典型性能影响 |
---|---|---|---|
常规训练 | 每N个batch | ❌ 可选 | <5% 减速 |
大模型加载 | 每次分段前 | ✔️ 必需 | 可避免OOM |
多任务切换 | 任务边界 | ✔️ 推荐 | 可复用显存 |
异常恢复 | 按需 | ✔️ 关键 | 恢复成功率+50% |
调试阶段 | 任意位置 | ❌ 避免 | 干扰内存分析 |
高级技巧
与内存分析工具配合:
print(torch.cuda.memory_summary()) # 清理前torch.cuda.empty_cache()print(torch.cuda.memory_summary()) # 清理后
PyTorch Lightning 集成:
class MyModel(LightningModule): def on_train_batch_end(self): if self.current_epoch % 10 == 0: torch.cuda.empty_cache()
显存碎片化监控:
def check_fragmentation(): allocated = torch.cuda.memory_allocated() reserved = torch.cuda.memory_reserved() if reserved - allocated > 1e9: # 碎片>1GB torch.cuda.empty_cache()
何时应该避免调用
- 在关键性能路径上:如高频调用的损失函数内使用
torch.no_grad()
块时:此时无梯度缓存需要清理确定无显存泄漏时:过度调用会导致不必要的同步点合理使用此方法可将GPU利用率提升15-30%(特别是在大模型训练中),但需要结合具体场景权衡性能与显存占用的平衡。