掘金 人工智能 06月22日 17:02
极客说|强化学习(RL)与有监督微调(SFT)的选择以及奖励函数的优化
index_new5.html
../../../zaker_core/zaker_tpl_static/wap/tpl_guoji1.html

 

本文深入探讨了强化学习(RL)和监督微调(SFT)在AI模型训练中的差异,并通过具体案例详细阐述了奖励函数的优化策略。文章首先对比了SFT与RL的实现方式,然后重点分析了“先SFT后RL”的训练策略优势,以及RL中常见问题的解决方案。最后,通过实例分析,阐述了如何设计和优化奖励函数,以提升模型性能。

🤖 监督微调(SFT)类似于“老师教学”,通过提供问题和标准答案,让模型模仿学习,其优点是安全稳定,但可能限制模型的创造性。

💡 强化学习(RL)则采用“鼓励与惩罚”机制,引导模型自主探索和优化策略,其优势在于模型能主动探索,但可能面临KL爆冲、梯度爆炸等风险。

✅ “先SFT后RL”的训练策略更具优势,SFT能提升训练稳定性、提高数据利用效率,并降低人工标注成本,RL则用于进一步优化,尤其适用于需要结构化输出的场景。

💥 RL训练中,KL爆冲、梯度爆炸和模型崩盘是常见问题,这些问题通常源于奖励函数设计不当或超参设置错误,需要通过调整KL惩罚系数、优化奖励函数等方式解决。

📊 奖励函数的设计至关重要,如通过细化奖励分值、增加群组投票等方式,可以提升模型在RL阶段的性能,例如,在答案正确性奖励中,可以设置完全正确、差1分和错误等不同级别的奖励,以引导模型更准确地输出答案。

「极客说」 是一档专注 AI 时代开发者分享的专栏,我们邀请来自微软以及技术社区专家,带来最前沿的技术干货与实践经验。在这里,您将看到深度教程、最佳实践和创新解决方案。关注「极客说」,与行业顶尖专家一起探索科技的无限可能!投稿请联系:17278094563(微信号)

本文首先将阐述强化学习(RL)和监督微调(SFT)在实现方式上的区别,然后通过一个具体案例,详细说明如何对奖励函数进行优化。

从简单例子入手理解 SFT 和 RL

监督微调(SFT)- 像老师教学生

监督微调(Supervised Fine-Tuning,简称 SFT)相当于作为老师,自己先列出很多问题,再告诉模型标准的回答,比如用数据(训练集)教它:

我们让模型一遍又一遍模仿训练语料中的标准答案,直到我们符合要求。

SFT 具体步骤(算法的介绍)

    我们拿出一个问题:苹果什么颜色?模型自己尝试回答:比如它乱回答成 蓝色。我们就立马纠正,告诉它正确的答案应该是红色,给它一个明确的误差信号: [ 误差 = - log P("红色") ]然后模型用这个误差信号帮助它更新自己说法,让下次“红色”概率增加。

所以,监督学习过程如下:

for 问题, 标准答案 in 数据集:    模型答案 = 模型生成(问题)    误差 = 计算交叉熵Loss(模型答案, 标准答案)    模型更新(误差)

优点:安全、稳定

缺点:模型永远只能模仿,不太能创造性地发现新答案。

强化学习(RL)– 让模型自己摸索

强化不直接教标准答案,而是用“鼓励”和“惩罚”引导模型。

我们问模型:“1加1等于?”

模型得到这些奖励和惩罚之后,会慢慢去摸索和记忆,知道怎么才能得到更多奖励(而不是直接告诉它标准答案)。

强化学习大致算法:

# RL过程:for 问题 in 数据集:    # 让鹦鹉自由生成多个答案(探索)    多个答案 = 模型生成多个可行答案(问题)     # 每个答案给奖励    for 每个答案 in 多个答案:        奖励 = 奖励函数(每个答案)        更新策略(奖励 * log(生成该答案概率))

优势:模型能够自己发现最优策略,能主动“探索”,学得更主动;

危险:但探索过猛容易产生 KL 爆冲、梯度爆炸、最终模型崩盘。

SFT 和 RL 选择

大多数情况下训练模型先 SFT 再 RL 更安全、更高效,尤其是对能力尚弱的小模型或需要严格格式输出的任务。不过这并不是绝对法则,下面补充几点可作为快速校验的要点。

为什么“先 SFT 后 RL”通常更好

训练稳定性

数据利用效率

人工标注成本

直接 RL 的合理场景

    几乎没有标注数据、但可以自动计算奖励,例如:解数独、玩 Atari 游戏,环境本身给出分数。大模型已具备强基础能力 GPT-4、Claude 3-Sonnet 这一级别,格式和基本推理已比较稳,直接 RL(或 RLAIF)效果也可接受。任务鼓励高多样性、无法提供单一“标准答案” 如创意写作、对话风格优化,仅用偏好打分即可训练。

实践经验速查表

    我们的奖励函数是不是完全依赖“答案==标准答案”? 如果是,说明我们已经有明确标注;SFT 通常先做更划算。我们有多大 GPU/TPU 预算? RL(尤其 GRPO/PPO)往往需要比 SFT 高 2-4 倍的算力。任务对“推理链”可解释性要求高吗? 先 SFT(教会标签格式)再 RL(提升正确率)更容易满足可解释输出。

结论

“先 SFT 再 RL”并非硬性规定,但在绝大多数需要结构化输出、且有可用标注的场景下是最省心、最稳妥的路径。只有当标注极少或任务天然提供可计算奖励时,才会优先考虑“直接 RL”。

RL 常见问题

前文提到的 RL 常见的 KL 爆冲、梯度爆炸、模型崩盘问题,本小节详细介绍。

一般情况下,这三个问题会组成一条「连锁反应」:

奖励函数设计不佳或超参错误      ↓↓导致↓↓   KL爆冲 --> 梯度爆炸 --> 模型参数剧烈变化或NaN      ↓↓进一步导致↓↓   模型崩盘 (输出单一、低质)

KL 爆冲

KL 散度(Kullback–Leibler Divergence)本质上衡量的确实是两个概率分布之间的差距。在 DPO(Direct Preference Optimization)方法中,参考模型(reference model)和 训练中模型(policy model)之间计算的就是 KL 散度

用简单例子解释一下:

假设默认模型只会讲三句话:“我们好”、“谢谢”、“再见”。

它现在的“说话概率”(也可以叫“原始概率分布”)是:

我们心目中理想的“模型应该说话的概率分布”(目标概率分布)是:

我们希望模型朝着目标概率(Q 分布)学习,但它原本的习惯是当前概率(P 分布)。

这时候,为了知道我们的鹦鹉目前的概率分布 P目标概率分布 Q 差距有多远。

在例子中,如果原来模型会说:“我们好(Hello)”,但我们想教它说:"谢谢(Thank you)",那么就有了:

假设我们给了模型过分高的奖励,比如只要提到“谢谢”,我们奖励20分。模型会在几步内学得太猛,突然所有问题只回复:“谢谢谢谢!”这就是 KL 距离瞬间爆发。

KL 爆冲发生以后,需要用算法调整 KL 惩罚系数(β)

Loss 总 = 奖励损失 + β × KL 散度

提高 β,比如 0.01 → 0.1,约束模型变化的幅度。

梯度爆炸

深度学习中很常见的梯度爆炸问题主要是指:

最常见导致梯度爆炸的情况,很少是简单的代码 Bug;事实上更多是算法超参设置不当或数值计算不稳定导致的:

未使用梯度裁剪或裁剪设置值过大:如果训练过程中未用梯度裁剪方法,或梯度裁剪的上限值设置过大(如10以上),一旦梯度猛增就不能约束,即可引发梯度爆炸。

算法表现为梯度值剧烈变大甚至 NaN。

模型崩盘

模型崩盘的本质含义是:

模型崩盘有典型的指标,例如:

算法上,熵的定义是:

熵值 = -sum( p(X_i)*log(p(X_i)) )# 熵越低,表示模型生成的语言越单调单一,越接近崩盘

一种典型的模型崩盘的表现是:

模型崩盘最常见的直接原因是源于强化学习训练过程本身的一系列内在问题(尤其是强化学习),例如:

如果出现上述问题我们还继续训练,鹦鹉最后脑袋就真的弄坏了。比如它彻底只会一招,一问就吐出“苹果苹果”或彻底傻掉不回话,再训练也没用(模型崩溃)。

SFT 与 GRPO 的两阶段训练

接下来,参考 repo 中 code 目录下的训练代码,我们详细介绍 SFT 和 GRPO 的区别。

说明

要在 Hugging Face Hub 搜索 “GAIR/LIMO” 和 “openai/gsm8k” 即可查看与下载完整数据。

两阶段各自“训练了什么”?

阶段①:SFT(Supervised Fine-Tuning)

训练信号

学到内容

不学/很少学到

阶段②:GRPO(Reinforcement Learning, KL-regularized

训练信号

学到内容

不再关注

在两阶段训练中:

两阶段原始数据集字段如下:

训练脚本脚本 map 之后变成如下格式:

这样便于核对:

SFT 中的训练格式解释

如上一段内容解释,SFT map 后的训练语料并没有 completion 字段。

在 SFTConfig 训练代码里设置了

completion_only_loss=False

这表示“不要只对 completion 计算损失,而是对整条 prompt 进行 teacher-forcing”。在这种模式下,SFTTrainer 并不需要单独的 completion 字段——只要有一列 prompt 含完整参考答案即可。

    但 SFTTrainer 源码要求数据集中必须存在 completion 这一列(无论用不用)。为了省事就补了空 字符串占位,使得字段齐全、代码不报错。为什么不把 answer 放进 completion? 如果我们设 completion_only_loss=True,那就需要把 <answer>25</answer> 部分挪到 completion,让 prompt 只包含系统提示 + question + <reasoning>…</reasoning>。 当前脚本选用整串 CE 方式,所以 completion 留空即可。

简而言之:

SFT 训练损失函数的构建

三种构建 SFT 损失函数方案

把 COT + answer 全放到 completion(方案 C)会发生什么?

1、prompt 只剩 “系统提示 + 题干”,长度变短 → 同批显存更低;2、model 在训练时只要“读题干 → 预测 reasoning+answer”, 形成经典的 Instruction → Target 教师强制结构;3、优势

4、可能副作用

如何选择?

在我们当前“小数据、 1 epoch”的设置下,整串 CE 提供最稠密梯度;如果未来扩充 LIMO 到数万条并跑多 epoch,可以考虑方案 C,并在 RL 阶段继续用格式奖励守护模板,以获得更高数值准确率且不过拟合冗长 COT。

如果改成方案 C——把整个

<reasoning>……</reasoning><answer>……</answer>

都放进 completion,只让交叉熵监督这段文本,但增加一个格式类奖励仍然是最稳妥的做法。理由与操作要点如下。

设计欠佳奖励函数(优化前的奖励函数)

在强化学习训练中,答案正确性的判断通常通过自动化脚本实现,而非依赖人工标注的表格。以下是具体实现逻辑。

格式奖励函数(format_reward_func

目标

确保模型输出符合预设的 XML 标签结构 <reasoning>...</reasoning><answer>...</answer>

代码实现

import redef format_reward_func(completions, **kwargs):    """检查输出是否符合XML标签格式"""    pattern = r"^<reasoning>[\s\S]*?<\/reasoning>\s*<answer>[\s\S]*?<\/answer>$"    responses = [completion[0]["content"] for completion in completions]    rewards = [1.0if re.match(pattern, response) else0.0for response in responses]    return rewards

逻辑解析

正确性奖励函数(correctness_reward_func

目标

验证模型输出的数值答案是否与标准答案一致。

代码实现

def correctness_reward_func(completions, answer, **kwargs):    """检查答案是否正确"""    responses = [completion[0]["content"] for completion in completions]    extracted_responses = [extract_last_xml_answer(response) for response in responses]    rewards = [        2.0if extracted == correct else0.0        for extracted, correct inzip(extracted_responses, answer)    ]    return rewards

依赖函数 extract_last_xml_answer

def extract_last_xml_answer(response):    """从XML标签中提取答案(若格式错误,则取最后一个数字)"""    try:        # 尝试解析XML标签        answer = re.search(r"<answer>(.*?)</answer>", response).group(1).strip()        return answer    except:        # 格式错误时,提取最后一个数字        numbers = re.findall(r"\d+\.?\d*", response)        return numbers[-1] if numbers else""

逻辑解析

总奖励计算

关键设计考量

奖励函数优化思路

奖励函数优化主要包含:

    细化奖励分值增加群组投票

数字奖励 cor_reward (0 / 1 / 2 分)

XML_RE  = re.compile(r"<answer>(.*?)</answer>", re.S)_num    = lambda x: re.sub(r"[%$,]", "", x).strip()def _extract_nums(text: str):    return [_num(m) for m in XML_RE.findall(text)]def cor_reward(completions, **kw):    answers = kw.get("answer") or kw.get("answers") or []    rewards = []    for cand_list, gt inzip(completions, answers):        # 1) 收集 8 条回答里的所有 <answer>…</answer> 数字        nums = [            n            for c in cand_list            for n in _extract_nums(c["content"])        ]        # 2) 若一个数字都没抓到 → 直接 0 分        ifnot nums:            rewards.append(0.0)            continue        # 3) 群组投票:出现次数最多的数字        vote = Counter(nums).most_common(1)[0][0]        # 4) 评分:完全对 +2,差 1 +1,其余 0        diff = abs(int(vote) - int(gt)) if vote.isdigit() and gt.isdigit() else999        if   diff == 0: rewards.append(2.0)        elif diff == 1: rewards.append(1.0)        else:           rewards.append(0.0)    return rewards

详细步骤

1、_extract_nums()

2、组内投票(majority vote)

vote = Counter(nums).most_common(1)[0][0]

3、分级奖励 diff = |vote - ground_truth| – diff == 0 → +2 (完全正确) – diff == 1 → +1 (只差 1,也给部分梯度) – else     → 0 (远离真值) 这样 early 训练阶段更容易拿到非零 reward,梯度稠密,KL 更平滑。

输出示例

batch_size = 8cor_reward → [2,1,0,2,0,1,0,2]fmt_reward → [1,1,0,1,1,1,0,1]total_reward → [3,2,0,3,1,2,0,3]

新旧奖励函数对比:

结果:

群组 vote 的合理性研究

先投票再对真值”合不合理,要看我们希望奖励函数起什么作用。

合理方面

1、自洽性(Self-Consistency)的经验规律 OpenAI、Google 论文都表明: “同一 prompt 让模型多生成几条推理,用众数/平均值作为最终答案, 准确率往往高于单条输出。” 投票奖励把这个经验直接注入 RL:

2、梯度密度更高

3、利用并行生成的计算成本 既然我们已经花显存一次性生成了 8 条回答,把它们全都用来评奖要 比只看第一条更物超所值。

4、格式门控 + 数值投票分离 先用格式奖励约束输出形状,再用投票奖励评数值;两部分可独立调 权重,互不干扰。

局限性

    “集体跑偏” 如果模型内部存在系统性错误(8 条都写 41,但真值 42),投票仍会 选错。此时 reward 仍给 0 / 1,梯度作用有限。并列众数的歧义 Counter.most_common(1) 默认返回先出现的数字; 若票数打平,选择具有随机性,可能带来噪声。 → 可以设阈值:只有票数 ≥4 才用众数,否则 reward=0。差值阈值的 trade-off ‑ 差 1 给 1 分能 densify 梯度; ‑ 但如果阈值太宽(差 5 也给分)会削弱“完全正确”的驱动力。生成条数与开销 num_generations=8 对 A100 2B 模型还算轻;如果用更大的模型或者 更长 completion,生成 8 条会拖慢训练。

如何让投票更棒(进一步优化的可能性)

结论

因此,是否保留投票机制取决于:

训练结果指标解读

SFTTrainer 日志里出现字段:

SFT、GRPO 通用字段:

GRPOTrainer 特有字段:

备注:

    fmt_reward/mean ≥ 0.9 → 模板输出稳定。cor_reward/mean ≥ 1.2 → 30 % 以上完全正确(好)。kl < 0.3 → 更新稳定;若突涨,需减小学习率 / β。frac_reward_zero_std < 0.3 → 奖励信号足够密集。completions/clipped_ratio > 0.4 → 说明 128 token 不够,可调大。

奖励函数优化对训练效果实测

奖励函数优化前

训练

source .venv/bin/activateroot@a100vm:~/Gemma-2-2B-IT-GRPO# pwd/root/Gemma-2-2B-IT-GRPOroot@a100vm:~/Gemma-2-2B-IT-GRPO# python gemma-grpo2.py root@a100vm:~/Gemma-2-2B-IT-GRPO# python  gemma-instruct-grpo2.py

训练中的资源利用率

(Gemma-2-2B-IT-GRPO) root@a100vm:~/Gemma-2-2B-IT-GRPO# nvidia-smiSun Jun 1511:44:162025       +-----------------------------------------------------------------------------------------+| NVIDIA-SMI 560.35.05              Driver Version: 560.35.05      CUDA Version: 12.6     ||-----------------------------------------+------------------------+----------------------+| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC || Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. ||                                         |                        |               MIG M. ||=========================================+========================+======================||   0  NVIDIA A100 80GB PCIe          Off |   00000001:00:00.0 Off |                    0 || N/A   49C    P0            109W /  300W |   80793MiB /  81920MiB |     48%      Default ||                                         |                        |             Disabled |+-----------------------------------------+------------------------+----------------------+                                                                                         +-----------------------------------------------------------------------------------------+| Processes:                                                                              ||  GPU   GI   CI        PID   Type   Process name                              GPU Memory ||        ID   ID                                                               Usage      ||=========================================================================================||    0   N/A  N/A    449318      C   python                                      80780MiB |+-----------------------------------------------------------------------------------------+

评估

python3 -m venv ~/eval-envsource ~/eval-env/bin/activatepip install "torch>=2.1""transformers>=4.49" datasets tqdmpip install acceleratepython gsm8k-eval-tf2.py --model_dir gemma-grpo-onlypython gsm8k-eval-tf2.py --model_dir gemma-sft-grpo

评估脚本执行结果

纯 GRPO

----------------------------------------Input tokens  avg=140.5  max=269Output tokens avg=90.9  max=257Correct format     : 1142/1319 (86.6%)Plausibly correct  : 566/1319 (42.9%)Exact correct      : 559/1319 (42.4%)========================================

SFT+GRPO

----------------------------------------Input tokens  avg=140.5  max=269Output tokens avg=74.7  max=257Correct format     : 1192/1319 (90.4%)Plausibly correct  : 504/1319 (38.2%)Exact correct      : 500/1319 (37.9%)========================================(eval-env) root@a100vm:~/Gemma-2-2B-IT-GRPO# 

奖励函数优化后

训练

source .venv/bin/activateroot@a100vm:~/Gemma-2-2B-IT-GRPO# pwd/root/Gemma-2-2B-IT-GRPOroot@a100vm:~/Gemma-2-2B-IT-GRPO# python gemma-grpo3.py root@a100vm:~/Gemma-2-2B-IT-GRPO# python  gemma-instruct-grpo3.py

评估

python3 -m venv ~/eval-envsource ~/eval-env/bin/activatepip install "torch>=2.1""transformers>=4.49" datasets tqdmpip install acceleratepython gsm8k-eval-tf2.py --model_dir gemma-grpo-onlypython gsm8k-eval-tf2.py --model_dir gemma-sft-grpo

评估脚本执行结果

仅 GRPO

----------------------------------------Input tokens  avg=140.5  max=269Output tokens avg=92.2  max=257Correct format     : 1120/1319 (84.9%)Plausibly correct  : 665/1319 (50.4%)Exact correct      : 657/1319 (49.8%)========================================(eval-env) root@a100vm:~/Gemma-2-2B-IT-GRPO# 

SFT+GRPO

----------------------------------------Input tokens  avg=140.5  max=269Output tokens avg=75.5  max=257Correct format     : 1161/1319 (88.0%)Plausibly correct  : 506/1319 (38.4%)Exact correct      : 505/1319 (38.3%)========================================(eval-env) root@a100vm:~/Gemma-2-2B-IT-GRPO# 

奖励函数优化前后对比(仅仅对比 GRPO)

结论

    数值准确率 新奖励把完全正确率提升了约 7 个百分点,这是只奖励 exact-match 的直接收益格式合规率 基本持平(因为并没有优化格式奖励)后续细化奖励规则,增加训练 step 数,准确率有望继续提升。

白皮书推荐

开发者们,别掉队!大语言模型正以前所未有的速度重塑技术格局。微软最新发布《大语言模型(LLM)上手指南》白皮书,涵盖 Microsoft Copilot 副驾驶® 在代码编写、Debug、创意发想等方面的强大功能详细解说。

点击下方链接,在角色一栏中填写“开发者”即可领取专属开发者的技术文档。

info.microsoft.com/GC-DevOps-W…

Fish AI Reader

Fish AI Reader

AI辅助创作,多种专业模板,深度分析,高质量内容生成。从观点提取到深度思考,FishAI为您提供全方位的创作支持。新版本引入自定义参数,让您的创作更加个性化和精准。

FishAI

FishAI

鱼阅,AI 时代的下一个智能信息助手,助你摆脱信息焦虑

联系邮箱 441953276@qq.com

相关标签

强化学习 监督微调 奖励函数 模型训练 AI
相关文章