掘金 人工智能 05月16日 11:03
【LLM RL】论文分享No.9:SWiRL(Multi-Step)
index_new5.html
../../../zaker_core/zaker_tpl_static/wap/tpl_guoji1.html

 

该论文提出了一种名为SWiRL的逐步强化学习方法,旨在解决大型语言模型在处理需要多步推理和工具使用的复杂任务时遇到的挑战。SWiRL通过合成数据生成和离线RL优化,提升模型在多跳问答、数学解题等任务中的表现。该方法的核心在于关注推理过程,弱化结果标签的影响,从而使模型更好地学习如何分解问题、适时调用工具并整合结果。实验结果表明,SWiRL在多个数据集上均取得了显著的性能提升,尤其是在多步推理和工具使用方面。

🧪 **合成数据生成:** SWiRL的第一阶段通过迭代提示语言模型并配备工具(如搜索引擎、计算器)来生成多步轨迹,模拟模型在解决复杂问题时的推理过程。该阶段还包括过程过滤和结果过滤两种策略,用于筛选高质量的训练数据。

🎯 **逐步强化学习优化:** SWiRL的第二阶段采用逐步强化学习的方法,在每一步骤中根据模型采取的行动和环境的反馈,给予奖励。通过这种方式,模型可以学习如何在多步推理过程中做出更明智的决策,从而提高整体性能。

📈 **实验结果验证:** 论文通过在HotPotQA、GSM8K等多个数据集上的实验,验证了SWiRL的有效性。实验结果表明,SWiRL在同分布和不同分布任务上均能提升模型性能,且在多步推理和工具使用方面表现出色。

🔑 **过程过滤的重要性:** 研究发现,相比于结果过滤,过程过滤对下游任务的准确率提升更为关键。这意味着在训练过程中,注重数据提炼的过程比单纯追求结果的正确性更为重要。SWiRL能够从过程合理但结果可能不对的数据中学习,从而获得更好的效果。

论文名称:Synthetic Data Generation & Multi-Step RL for Reasoning & Tool Use

论文:arxiv.org/abs/2504.04…

机构:斯坦福大学计算机科学系 + Google DeepMind

时间线:2025/04/07(submitted)

简介

这篇论文提出一个逐步强化学习(SWiRL)方法,用来解决LLM在处理如多跳问答、数学解题等需要多步推理和工具使用的复杂任务时面临的挑战,这些任务要求模型分解问题、适时调用工具并合成结果,而传统RL方法(如RLHF、RLAIF)聚焦单步优化无法有效应对。所以SWiRL通过合成数据和离线RL优化的策略,能够提升模型多步推理和工具使用能力。

Stage-1:Synthetic Data Generation

流程概述

图1展示的是SWiRL的第一阶段:合成数据生成 。流程如下:

① 过程过滤:用LLM判断所有步骤的“process label”是不是都是正向的(合理的 ),如果是,这些数据就成为“Process - filtered data”(过程过滤后的数据 )。

② 结果过滤:判断最终答案“Final Answer”和“golden answer”(标准答案 )相比是不是正确的,如果正确,这些数据就成为“Outcome - filtered data”(结果过滤后的数据 )。

实现细节

① 通过为语言模型配备工具(如搜索引擎、计算器 ),迭代提示模型生成多步轨迹。

② 每一步模型需选择调用工具或给出最终答案,还可生成推理链(CoT)。

③ 若调用工具,会解析并在环境中执行,结果供下一步使用。

① 用 HotPotQA 训练集 10000 个多步问题生成 50000 条合成轨迹(每个问题 5 条 ),构建多跳问答数据集。

② 用 GSM8K 训练集 7500 个问题生成 37500 条合成轨迹,构建数学推理数据集。

③ 针对 HotPotQA,过滤掉通常单步查询就能回答的 “Easy” 问题,并分别为 HotPotQA(最多 5 步 )和 GSM8K(最多 10 步 )问题设置最大步数限制。

① 考虑四种过滤策略,即不过滤、过程过滤(根据先前步骤判断每步合理性 )、结果过滤(仅依据最终答案是否匹配标准答案 )、过程和结果过滤(取两者交集 )。

② 之前有工作(DeepSeek-R1)表明结果过滤的合成数据在单步RL和SFT中有良好表现,但本论文发现对于SFT而言,多步轨迹的正确性过滤有效且关键,但 SWiRL 与SFT不同,不单纯以结果论,更注重推理过程中的合理性,利用这些过程合理但结果可能不对的数据,反而能让模型的学习效果更优 。

Stage-2:Step-Wise RL Optimization

图2展示的是SWiRL的第二阶段:逐步强化学习优化 。

还是从“Prompt”(粉红色框 )开始,这里的例子是“Who is older Glenn Hughes or Ross Lynch?”,也就是提出一个问题,流程如下:

① 模型采取“Action 1”(浅蓝色框 ),比如决定去查询Glenn Hughes的年龄,具体表述为“To figure out who is older, I should first search for age of Glenn Hughes. <search_query>age of Glenn Hughes</search_query>” 。

② 然后会有“Env Response”(绿色框 ),不过图里没具体展示这一步响应内容。之后会有一个“Reward”(黄色框 ),奖励模型会根据前面的步骤,给这一步的操作打分。

① 基于第一步的情况,模型采取“Action 2” ,像这里是说“Next, I should find out what Ross Lynch’s age is. <search_query>Ross Lynch age</search_query>” ,也就是接着去查询Ross Lynch的年龄。

② 同样会有对应的“Env Response” ,然后奖励模型再根据前面步骤,给这一步操作打分,给出“Reward” 。

① 经过前面一系列步骤,到最后“Action N” ,这里就是给出响应(Response ),比如“Given the results of my previous searches, I have enough information to answer the question. Glenn Hughes” ,意思是根据前面查询结果,得出答案是Glenn Hughes 。

② 最后奖励模型还是会根据前面所有步骤,给这最后一步操作打分,给出“Reward” ,并且这个过程中奖励模型不参考标准答案 。

总体来说,就是模型在面对问题时,一步步采取行动、获取反馈,然后奖励模型不断给每一步行动打分,通过这样的方式来优化模型在多步推理和工具使用上的能力 。

Stage-3:Step-Wise Inference-time Evaluation

图3展示的是SWiRL多步推理过程 。

从“Prompt”(粉红色框 )开始,用户提出问题“Please help me answer the following question in just a few words...” ,这里具体问题是水果摊贩相关的数学问题:一个水果摊贩花80美元买了50个西瓜,以25%的利润卖掉,问每个西瓜卖多少钱。同时提示如果需要可以使用计算器(用标签表示 ),并告知结果输出格式。流程如下:

① 模型做出“Action 1”(浅蓝色框 ),分析出摊贩利润是25% ,利润金额为0.25 * 80 ,并给出原始价格80美元和利润计算式0.25*80 。

② 接着“Env response”(绿色框 ),用户计算出0.25 * 80 = 20 。

① 模型做出“Action 2” ,表示西瓜原始价格80美元加上利润20美元,得到总售价,给出计算式80+20 。

② “Env response”中,用户计算出80 + 20 = 100 。

① 模型继续推进,比如考虑西瓜数量,给出计算每个西瓜售价的式子100 / 50 。

② “Env response”里用户计算100 / 50 = 2 。

③ 最后到“Action N” ,模型得出最终答案,用标签标注为2 ,也就是每个西瓜售价是2美元 。

整个过程就是模型在面对问题时,通过多步操作,借助用户反馈,一步步推理计算,最终得出答案 。

实验结果

训练集

前文已经说到了,用的是HotPotQA和GSM8K,会在评测集部分简单介绍一下每个数据集的关键特征。

评测集

实际评测的时候,对每个问题迭代提示模型调用工具或给出最终答案,问答数据集最多查询5次,数学推理数据集最多查询10次。

关键结论

①【图4】通过对比4种过滤机制对下游任务准确率的影响,发现仅过程过滤普遍带来最高准确率,说明注重数据提炼过程更关键;除MuSiQue外,正确性过滤常降低性能,表明SWiRL能从正负例学习,凸显结果过滤相对不重要,且其过程RL方法可从答案错误轨迹有效学习。

②【表2】在HotPotQA或GSM8K的合成轨迹上进行SWiRL微调,能提升模型在同分布和不同分布任务上的性能,且在不同领域和工具使用的任务间进行训练,可有效提高模型的通用多步推理和工具使用能力。

③【图5】通过对比SFT和SWiRL性能,显示SFT在各数据过滤策略下均不如SWiRL,SFT适合用过程和结果过滤的数据,而SWiRL从仅过程过滤的数据中学习效果最好,因其能按步骤优化奖励,提升规划和泛化能力。

④【图6】SWiRL的多步工具使用推理能提升基础模型和SWiRL微调模型的性能,且SWiRL微调模型提升更显著,即便没有工具,SWiRL训练也能增强模型分解复杂问题的能力

⑤【图7】随着微调数据集规模增加,SWiRL在多步推理任务上的模型性能持续提升,100个数据点的数据集难以让模型有效泛化,1000个数据点时模型性能显著改善,10000个数据点能进一步提升性能,且较小模型(2b和9b)在领域内受益于SWiRL,但泛化能力不如27b模型。

⑥ 【表3】经过多步RL优化后,SWiRL模型在分布内和分布外任务中,每一步的平均正确性均高于基线模型,表明其最终准确率的提升得益于更好的多步推理

总结

这篇论文的核心思路就是,将用CoA数据做RL的时候,相比于之前采用的单步RL,加了对每轮工具调用过程中的RL,引导模型关注推理过程,弱化结果标签的影响,在公开数据上做的实验也是证明了一些效果。但是,在实际落地的时候,困难还是挺多的,比如:

① 数据获取与标注瓶颈:SWiRL 依赖高质量的过程性标注数据(如分步推理轨迹),但这类数据的人工标注成本极高且效率低下,太难了。

②训练成本高:每条数据都是多步的训练方法,对于资源以及时间的消耗进一步提高。

③奖励函数设计与对齐困难:SWiRL 的性能高度依赖奖励函数的合理性,但多步骤推理任务的奖励信号往往稀疏且难以定义的。

不过,瑕不掩瑜,是一篇值得读的论文,推荐一波。

Fish AI Reader

Fish AI Reader

AI辅助创作,多种专业模板,深度分析,高质量内容生成。从观点提取到深度思考,FishAI为您提供全方位的创作支持。新版本引入自定义参数,让您的创作更加个性化和精准。

FishAI

FishAI

鱼阅,AI 时代的下一个智能信息助手,助你摆脱信息焦虑

联系邮箱 441953276@qq.com

相关标签

SWiRL 强化学习 多步推理 工具使用
相关文章