机器之心 03月02日
DeepSeek关键RL算法GRPO,有人从头跑通了,贡献完整代码
index_new5.html
../../../zaker_core/zaker_tpl_static/wap/tpl_guoji1.html

 

本文介绍了如何从头开始实现GRPO,基于Qwen2.5 - 1.5B - Instruct模型构建分布式强化学习流程,涵盖多个方面,包括基础设置、数据处理、函数定义、训练设置等,并展示了训练效果。

GRPO算法丢弃critic model,通过组内样本相对比较计算策略梯度

教程使用多个库和工具实现分布式训练,如PyTorch、Hugging Face等

数据准备中使用GSM8K数据集,通过对比答案为RL算法提供奖励

定义评估和奖励函数,评估模型并根据答案正确性分配奖励

实现GRPO算法的构建模块,进行训练设置和执行,展示训练效果

2025-03-02 11:54 北京

手把手教你从头跑通 GRPO

选自GitHub

作者:Andriy Burkov

机器之心编译


GRPO(Group Relative Policy Optimization)是 DeepSeek-R1 成功的基础技术之一,我们之前也多次报道过该技术,比如《DeepSeek 用的 GRPO 占用大量内存?有人给出了些破解方法》。


简单来说,GRPO 算法丢弃了 critic model,放弃了价值函数近似,转而通过组内样本的相对比较来计算策略梯度,从而有效降低了训练的不稳定性,同时提高了学习效率。


既然 GRPO 如此有效,那么,你知道如何从头开始实现 GRPO 吗?


近日,AI 工程师和技术作家 Andriy Burkov 发布了一份「从头开始写 GRPO 代码」的教程,其中介绍了如何基于 Qwen2.5-1.5B-Instruct 模型构建一个使用 GRPO 的分布式强化学习流程。


不过,在我们深入这份教程之前,先简单介绍一下它的作者。Andriy Burkov 算得上是 AI 领域的一位著名科普作家,在加拿大拉瓦尔大学取得了计算机科学博士学位,还曾发表过两本颇受欢迎的 AI 主题著作:《100 页语言模型书》和《100 页机器学习书》;书中一步步详实地介绍了相关概念,并附带了简明的实现代码。



接下来我们就来看看这份 GRPO 从头实现教程吧。



教程地址:https://github.com/aburkov/theLMbook/blob/main/GRPO_From_Scratch_Multi_GPU_DataParallel_Qwen_2_5_1_5B_Instruct.ipynb


从头编写 GRPO 代码

使用 Qwen2.5-1.5B-Instruct 的分布式实现


本教程将展示如何使用 GRPO 方法构建分布式强化学习(RL)流程,从而可以针对数学、逻辑和编程任务对语言模型进行微调。


首先需要明确,这些任务都存在一个唯一且正确的 ground truth 答案,可通过简单的字符串比较轻松加以验证。


GRPO 的发明者是 DeepSeek,最早是被用于微调 DeepSeek 的 R1 和 R1-Zero 模型 —— 它们可通过学习生成思维链(CoT)来更好地解决数学和逻辑问题。


本教程的目标是将通用语言模型 Qwen2.5-1.5B-Instruct 转换为数学问题求解器。我们将从头开始编写 GRPO 代码,然后将其与几个流行的库和工具集成起来,以实现分布式训练管道流程,包括:



本教程分为几个部分。首先是基本设置和导入,然后是数据格式化和答案提取、数据集准备、评估函数、奖励函数、训练设置和执行,最后加载和测试模型。此过程中,我们将从头实现 GRPO 算法。


Part 1:基础设置和导入


首先是安装并导入所有必要的模块。下面是导入库的一段代码截图。


部分代码截图。完整代码块参见 GitHub。


运行上述代码(参考项目完整代码),可以执行以下任务:



Part 2:数据格式以及答案提取


接下来,项目定义了数据格式,以及模型如何从输出和数据集中提取答案段落。


为了确保模型输出格式一致,项目还定义了一个系统提示。该提示指示模型生成包含 < reasoning > 和 < answer > 标签的输出。这一步通过两个函数完成:



部分代码截图。完整代码块参见 GitHub。


Part 3:数据准备


该项目使用 GSM8K 数据集进行训练。项目使用了该数据集中的示例来训练模型,基于强化学习(RL)训练范式,让模型生成多个问题解答样本,之后作者将这些解答与 GSM8K 示例中的标准答案进行对比,如果匹配,就为 RL 算法(GRPO)提供高奖励,然后更新模型权重,以增加模型下次获得高奖励的可能性。


实验过程是这样的。首先从 Hugging Face 加载数据集,然后格式化每个示例,包括系统提示和用户提示。这段实现代码中还定义了两个辅助函数:prepare_dataset 以及 build_prompt。


部分代码截图。完整代码块参见 GitHub。


Part 4:评估函数


评估对于跟踪模型的进展至关重要。因此作者定义了一些函数,从而可以在一组示例上对模型进行评估。该项目的评估函数执行以下任务:



在这段代码中,两个辅助函数 _extract_last_number 和 _extract_single_number 被用来从文本中提取数字。评估函数 evaluate_model 使用这些辅助函数来确定预测答案是否正确:


部分代码截图。完整代码块参见 GitHub。


Part 5:奖励函数


在强化学习中,奖励函数是必不可缺的,作者定义了两个奖励函数:


correctness_reward:这个函数根据生成的答案是否正确来分配奖励。采用两种方式:精确的字符串匹配和数值等价检查,将模型输出的答案与预期答案进行比较。完全匹配会获得更高的奖励(2.0),而基于数值等价的匹配会获得较小的奖励(1.5)。


format_reward:这个函数鼓励模型遵循所需的类似 XML 的输出格式。它为生成文本中存在 < reasoning>、</reasoning>、<answer > 和 </answer > 标签提供小额奖励。


部分代码截图。完整代码块参见 GitHub。


Part 6:从头开始实现 DataParallel GRPO


这一节,我们将从头实现 GRPO 算法的所有构建模块。首先,这里假设运行代码的机器至少有 2 台 GPU。为此,这里要使用 PyTorch 的 DataParallel API 来将策略模型放在多个 GPU 核心上,每个 GPU 核心都有该模型的一个副本。然后将批量数据分散在这些 GPU 核心上完成处理。


部分代码截图。完整代码块参见 GitHub。


Part 7:训练设置和执行


这一节,我们将所有组件组合在一起,完成设置并开始训练。


首先,加载预训练的模型和 tokenizer,准备评估数据,然后使用上面从头实现的 train_with_grpo 进行强化学习微调。


关键步骤包括:



下面的代码会执行以下功能:



GRPO 训练流程使用的超参数如下。


训练配置


以下参数设定了使用上面的 GRPO 算法实现强化学习微调运行的配置:



在微调之前和之后都会对模型进行评估,以衡量准确率的提高情况。最后,将微调后的模型保存到 grpo_finetuned_model 目录中。


部分代码截图。完整代码块参见 GitHub。


教程中还给出了详细的执行情况,可作参考。



下面我们也简单看看其训练过程。


首先,初始配置后,我们可以看到运行 GRPO 之前的准确度为 23.33%。



然后经过 500 步的 1 轮 GRPO 迭代,下图展示了相关的训练动态:



训练完成后,自然还需要对模型进行新一轮的评估。这里采用了 30 个评估样本来进行评估,以下展示了其中一个模型回答正确的示例:



整体表现如何呢?可以看到,经过一轮 GRPO 之后,Qwen-2.5-1.5B-Instruct 模型答对了 30 问题中的 27 题,实现了 90% 的准确度。相较于 GRPO 之前的 23.33%,可说是实现了性能飞跃。




上面两张图展示了模型的学习过程动态,可以看到:平均奖励在 2.25 左右就趋于稳定了(理论最大值为 0.8 + 2.0 = 2.8)。相比于另一处微调的 Qwen-2.5-0.5B-Instruct(获得的平均奖励为 1.4),这个数字相当高了,参阅:https://github.com/aburkov/theLMbook/blob/main/GRPO_Qwen_0_5_Instruct.ipynb


如果使用更大的模型并允许更长的生成时间,模型正确解答问题的能力还将进一步提升。但是,如果要训练更大的模型,不仅需要将数据分布在多台 GPU 上,还需要将模型分开放在多台 GPU 上,这需要用到 DeepSpeed 或 FSDP(完全分片数据并行)等模型并行工具。


下面加载和测试已经微调的模型:


完整代码见原笔记本


加载完成后测试一下,首先问问 1+1 等于几:



可以看到,模型反复思考了很多次,终于认定确实等于 2。


多次测试后还可以发现,该模型没有学会生成序列结束(EOS)token,因此即使在 </answer> token 之后,输出序列仍会继续。这是预期的行为,因为我们使用的奖励函数中没有包含一个用于停止生成的奖励。我们也没有执行监督微调步骤 —— 该步骤可以让模型学会在 </answer> 之后立即生成 EOS。


你对这篇代码密集的教程怎么看?有没有让你产生在自己的电脑上实现 GRPO 的想法?


© THE END 

转载请联系本公众号获得授权

投稿或寻求报道:liyazhou@jiqizhixin.com


阅读原文

跳转微信打开

Fish AI Reader

Fish AI Reader

AI辅助创作,多种专业模板,深度分析,高质量内容生成。从观点提取到深度思考,FishAI为您提供全方位的创作支持。新版本引入自定义参数,让您的创作更加个性化和精准。

FishAI

FishAI

鱼阅,AI 时代的下一个智能信息助手,助你摆脱信息焦虑

联系邮箱 441953276@qq.com

相关标签

GRPO 分布式训练 Qwen2.5 - 1.5B - Instruct 强化学习 数据处理
相关文章