掘金 人工智能 05月09日 10:09
DeepSpeed 微调 LLaMA-2完整步骤
index_new5.html
../../../zaker_core/zaker_tpl_static/wap/tpl_guoji1.html

 

本文提供了使用DeepSpeed微调LLaMA-2模型的完整步骤,从环境配置、数据准备到训练脚本编写和部署优化,覆盖了单机多卡和多机分布式场景。内容包括硬件和软件依赖的准备,数据集的JSON格式要求,DeepSpeed配置文件的关键参数设置,以及微调脚本的编写。此外,还介绍了如何启动训练、监控显存、解读训练日志,以及模型导出和部署的方法。最后,针对CUDA内存不足、多机通信失败和微调效果不佳等常见问题,给出了详细的解决方案。

🛠️环境准备:包括硬件配置(至少2张A100 80GB GPU)和软件依赖(torch, transformers, datasets, accelerate, deepspeed, sentencepiece)的安装,以及通过Meta官方许可获取LLaMA-2访问权限并下载模型权重。

📚数据准备:数据集采用JSON格式,包含instruction和output字段,并提供数据加载脚本示例。数据集需要保存为train.jsonl和eval.jsonl文件。

⚙️DeepSpeed配置:通过ds_config.json文件配置DeepSpeed,关键参数包括zero_optimization.stage(启用ZeRO-3优化,最小化显存占用)和offload_optimizer(将优化器状态卸载到CPU,节省GPU显存)。同时,需要根据GPU显存调整train_micro_batch_size_per_gpu。

🚀微调脚本:使用transformers和trl库编写微调脚本train.py,加载模型和分词器,定义训练参数(如output_dir, per_device_train_batch_size, num_train_epochs),并使用SFTTrainer创建Trainer对象,最后启动训练。

💡监控与优化:训练过程中,使用nvidia-smi监控显存占用,关注loss下降曲线和grad_norm。如果出现OOM,可以减小per_device_train_batch_size或增加gradient_accumulation_steps。

以下是使用 DeepSpeed 微调 LLaMA-2 的完整步骤,涵盖环境配置、数据处理、训练脚本编写和部署优化,适用于单机多卡或多机分布式场景:


1. 环境准备

(1) 硬件要求

(2) 软件依赖

# 创建conda环境conda create -n llama2 python=3.10 -yconda activate llama2# 安装核心库pip install torch==2.0.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117pip install transformers==4.31.0 datasets accelerate deepspeed sentencepiece

(3) 获取LLaMA-2访问权限

    申请Meta官方许可:ai.meta.com/resources/m…下载模型权重(选择7B/13B/70B版本):
    huggingface-cli login  # 使用HF账号认证huggingface-cli download meta-llama/Llama-2-7b-hf --local-dir ./llama-2-7b

2. 数据准备

(1) 数据集格式

(2) 数据加载脚本

from datasets import load_datasetdataset = load_dataset("json", data_files={"train": "train.jsonl", "eval": "eval.jsonl"})

3. 配置DeepSpeed

(1) 创建DeepSpeed配置文件ds_config.json

{  "train_micro_batch_size_per_gpu": 4,  "gradient_accumulation_steps": 8,  "optimizer": {    "type": "AdamW",    "params": {      "lr": 2e-5,      "weight_decay": 0.01    }  },  "fp16": {    "enabled": true,    "loss_scale_window": 100  },  "zero_optimization": {    "stage": 3,    "offload_optimizer": {      "device": "cpu"    },    "allgather_partitions": true,    "allgather_bucket_size": 5e8  },  "steps_per_print": 50}

(2) 关键参数说明


4. 微调脚本train.py

import torchfrom transformers import (    LlamaForCausalLM,    LlamaTokenizer,    TrainingArguments)from trl import SFTTrainer# 1. 加载模型和分词器model = LlamaForCausalLM.from_pretrained(    "./llama-2-7b",    torch_dtype=torch.float16,    device_map="auto")tokenizer = LlamaTokenizer.from_pretrained("./llama-2-7b")tokenizer.pad_token = tokenizer.eos_token# 2. 定义训练参数training_args = TrainingArguments(    output_dir="./output",    per_device_train_batch_size=4,    num_train_epochs=3,    logging_steps=10,    save_steps=500,    fp16=True,    deepspeed="./ds_config.json",  # 关键:启用DeepSpeed)# 3. 创建Trainertrainer = SFTTrainer(    model=model,    args=training_args,    train_dataset=dataset["train"],    dataset_text_field="instruction",  # 指定文本字段    max_seq_length=1024,    tokenizer=tokenizer)# 4. 开始训练trainer.train()

5. 启动训练

(1) 单机多卡(4卡为例)

deepspeed --num_gpus=4 train.py

(2) 多机训练

# 主机(rank=0)deepspeed --hostfile=hostfile --master_addr=主节点IP --master_port=29500 train.py# hostfile内容示例node1 slots=4  # 第一台机器4卡node2 slots=4  # 第二台机器4卡

6. 监控与优化

(1) 显存监控

watch -n 1 nvidia-smi  # 实时查看GPU利用率

(2) 训练日志解读


7. 模型导出与部署

(1) 保存微调后的模型

trainer.save_model("./llama-2-7b-finetuned")

(2) 转换为推理优化格式

# 导出为vLLM兼容格式(如需高性能API)python -m vllm.entrypoints.convert_model ./llama-2-7b-finetuned --output-dir ./llama-2-7b-vllm

(3) 启动推理服务

python -m vllm.entrypoints.api_server --model ./llama-2-7b-vllm --tensor-parallel-size 4

常见问题解决

    CUDA内存不足

      ds_config.json中启用"offload_param": {"device": "cpu"}使用梯度检查点:model.gradient_checkpointing_enable()

    多机通信失败

      检查防火墙设置sudo ufw allow 29500确保所有节点SSH免密登录

    微调效果不佳

      尝试调整学习率(1e-5到5e-5)增加数据量或使用数据增强

性能参考(7B模型)

硬件Batch Size显存占用Tokens/sec
1x A100 80GB472GB1200
4x A100 80GB164×42GB4800

如果需要针对特定场景(如医疗文本微调)的完整代码库,可以参考:

Fish AI Reader

Fish AI Reader

AI辅助创作,多种专业模板,深度分析,高质量内容生成。从观点提取到深度思考,FishAI为您提供全方位的创作支持。新版本引入自定义参数,让您的创作更加个性化和精准。

FishAI

FishAI

鱼阅,AI 时代的下一个智能信息助手,助你摆脱信息焦虑

联系邮箱 441953276@qq.com

相关标签

DeepSpeed LLaMA-2 微调 分布式训练 AI模型
相关文章