掘金 人工智能 12小时前
如何使用LoRA通过微调增强大模型
index_new5.html
../../../zaker_core/zaker_tpl_static/wap/tpl_guoji1.html

 

本文分享了LLM(大型语言模型)微调的实践经验,涵盖环境配置、数据集准备、模型选择和训练过程。作者以Meta-Llama-3-8B-Instruct模型为例,使用SFT(监督式微调)方法,在甄嬛传数据集上进行微调,最终完成了一个基础的LLM微调项目。

⚙️ 环境搭建:文章首先介绍了微调所需的关键环境,包括CUDA、PyTorch版本以及必要的依赖包,并提供了pip安装命令和更换镜像源的建议,方便读者快速配置环境。

📚 数据准备:文章详细介绍了SFT和DPO两种微调方式所需的数据集结构,并选择了甄嬛传数据集。同时,文章也给出了SFT数据集的具体结构,包括instruction、input和output,为后续的数据处理奠定了基础。

💻 训练代码:文章提供了完整的训练代码,包括模型加载、数据处理、LoRA配置、训练参数设置和训练启动。其中,数据处理函数是关键,它将原始数据转换为模型可接受的输入格式。LoRA配置则用于高效地进行模型微调。

⏱️ 训练过程:文章分享了训练过程中的经验,包括模型选择、数据集的读取和预处理、LoRA参数的配置,以及训练参数的设置。作者还强调了模型下载的耗时问题,并最终完成了模型的训练。

一直以来更加把精力放在AI的应用上,但是随着时间的推移发现模型微调也是AI应用一个无法迈过的坎,学习的路还很长,写文章也是督促自己继续下去的动力

环境安装

首先我们要做的就是微调模型的环境安装,这里我列出一下我使用的版本

CUDA版本

pytorch版本

依赖包

pip install modelscope==1.16.1pip install transformers==4.43.1pip install accelerate==0.32.1pip install peft==0.11.1pip install datasets==2.20.0

如果发现下载缓慢可以更换pip的仓库源

pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple

数据准备

这里我们可以从魔塔中寻找自己感兴趣的训数据集

李火旺的数据集一下就进入我是视野,这里有意思的数据集好像也不多

寻找数据集时需要注意找我们的训练方式可以直接使用的数据集,我这边使用的训练方式是SFT,火子哥这个数据集是DPO方式训练的,我不能直接使用,但是找到了另一个甄嬛的数据集,也是挺有意思的

数据集的数据结构

这里只介绍两种数据集的结构,下面这种我们决定使用的数据集就是甄嬛传的数据集

SFT(监督式微调,Supervised Fine-Tuning)

    instruction 提示语/用户输入/上下文input 可选补充信息,如果为空表示没有附加输入output 模型应该生成的答案(监督信号)

DPO 训练是一种“偏好对比学习”

    prompt 用户输入,给模型看的提示语chosen 一个好的(偏好更高的)模型回答rejected 一个差的、被拒绝的回答

这是chatgpt给出的区别解释

项目SFT(有监督微调)DPO(直接偏好优化)
🔧 训练目标学会输出“标准答案”学会偏好“更好”的回答
📊 数据需求单条 input + output多条 input + (chosen, rejected) 对比数据
📦 数据难度易收集(只需一个好答案)难收集(要构造两个回答并打分优劣)
🧠 模型训练方式监督学习近似偏好学习(强化学习的简化版)
💬 学到的能力更加规范、模板化更符合人类偏好,更自然、人性化
🚧 训练稳定性高,收敛快依赖对比样本质量,调参更敏感
⚠️ 错误传播风险输出错误直接作为学习目标,会放大数据缺陷错误样本影响小,因训练的是倾向
✅ 模型偏好优化强,适用于对齐阶段(如RLHF)
🤖 生成多样性低,偏模板化高,鼓励多样表达

训练代码

引入必要依赖

from datasets import Datasetimport pandas as pdimport torchfrom transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainer, GenerationConfigfrom peft import LoraConfig, TaskType, get_peft_model
    datasets.Dataset 用于构造训练数据集pandas 用于读取和处理 json 数据transformers 加载预训练模型、分词器、训练参数等peft 进行 LoRA 低秩适配训练

数据处理函数

def process_func(example):    MAX_LENGTH = 384    # Llama分词器会将一个中文字切分为多个token,因此需要放开一些最大长度,保证数据的完整性    input_ids, attention_mask, labels = [], [], []    instruction = tokenizer(f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n现在你要扮演皇帝身边的女人--甄嬛<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{example['instruction'] + example['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", add_special_tokens=False)  # add_special_tokens 不在开头加 special_tokens    response = tokenizer(f"{example['output']}<|eot_id|>", add_special_tokens=False)    input_ids = instruction["input_ids"] + response["input_ids"] + [tokenizer.pad_token_id]    attention_mask = instruction["attention_mask"] + response["attention_mask"] + [1]  # 因为eos token咱们也是要关注的所以 补充为1    labels = [-100] * len(instruction["input_ids"]) + response["input_ids"] + [tokenizer.pad_token_id]    if len(input_ids) > MAX_LENGTH:  # 做一个截断        input_ids = input_ids[:MAX_LENGTH]        attention_mask = attention_mask[:MAX_LENGTH]        labels = labels[:MAX_LENGTH]    return {        "input_ids": input_ids,        "attention_mask": attention_mask,        "labels": labels    }

训练代码

if __name__ == "__main__":    model = AutoModelForCausalLM.from_pretrained('./LLM-Research/Meta-Llama-3___1-8B-Instruct', device_map="auto",torch_dtype=torch.bfloat16)    model.enable_input_require_grads() # 开启梯度检查点时,要执行该方法    tokenizer = AutoTokenizer.from_pretrained('./LLM-Research/Meta-Llama-3___1-8B-Instruct', use_fast=False, trust_remote_code=True)    tokenizer.pad_token = tokenizer.eos_token    # 将JSON文件转换为CSV文件    df = pd.read_json('huanhuan.json')    ds = Dataset.from_pandas(df)    tokenized_id = ds.map(process_func, remove_columns=ds.column_names)    config = LoraConfig(        task_type=TaskType.CAUSAL_LM,        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],        inference_mode=False, # 训练模式        r=8, # Lora 秩        lora_alpha=32, # Lora alaph,具体作用参见 Lora 原理        lora_dropout=0.1# Dropout 比例    )    model = get_peft_model(model, config)    model.print_trainable_parameters() # 打印总训练参数    args = TrainingArguments(        output_dir="./output/llama3_1_instruct_lora",        per_device_train_batch_size=4,        gradient_accumulation_steps=4,        logging_steps=10,        num_train_epochs=3,        save_steps=100,        learning_rate=1e-4,        save_on_each_node=True,        gradient_checkpointing=True    )    trainer = Trainer(        model=model,        args=args,        train_dataset=tokenized_id,        data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),    )    trainer.train() # 开始训练
训练代码的补充解释
加载模型和分词器
model = AutoModelForCausalLM.from_pretrained(...)tokenizer = AutoTokenizer.from_pretrained(...)
读取数据集并转换格式
df = pd.read_json('huanhuan.json')ds = Dataset.from_pandas(df)tokenized_id = ds.map(process_func, remove_columns=ds.column_names)
LORA配置并应用到模型
config = LoraConfig(...)model = get_peft_model(model, config)
训练参数配置
args = TrainingArguments(...)
创建 Trainer 并开始训练
trainer = Trainer(...)trainer.train()

模型选择

模型我们选择一个8B的小模型,也是在modelscope下载就可以

Meta-Llama-3-8B-Instruct

可以使用下面的代码下载

import torchfrom modelscope import snapshot_download, AutoModel, AutoTokenizerimport osmodel_dir = snapshot_download('LLM-Research/Meta-Llama-3-8B-Instruct', cache_dir='/root/autodl-tmp', revision='master')

下载时间有点长,可能要耐心等待一下

耗时2个小时左右,漫长的等待,都是钱啊

开始训练

我设置了进行3轮训练

整个流程耗时20分钟左右,最大的耗时还是模型下载

微调后对话

就写到这里,后续有什么新的内容再分享吧,微调学习才刚起步还需要多多实践,还有数据集制作等很多内容要学习

Fish AI Reader

Fish AI Reader

AI辅助创作,多种专业模板,深度分析,高质量内容生成。从观点提取到深度思考,FishAI为您提供全方位的创作支持。新版本引入自定义参数,让您的创作更加个性化和精准。

FishAI

FishAI

鱼阅,AI 时代的下一个智能信息助手,助你摆脱信息焦虑

联系邮箱 441953276@qq.com

相关标签

LLM 微调 SFT Meta-Llama-3
相关文章