掘金 人工智能 07月15日 10:16
将重排序大模型Qwen3-Reranker-8B的知识蒸馏到小模型BGE-reranker-v2-m3上
index_new5.html
../../../zaker_core/zaker_tpl_static/wap/tpl_guoji1.html

 

本文分享了一种低成本优化RAG(检索增强生成)中重排序模型的方法,通过知识蒸馏技术,将强大的Qwen3-Reranker-8B模型的知识迁移到较小的BGE-reranker-v2-m3模型上。实验结果表明,该方法在stackoverflowdupquestions-reranking数据集上取得了显著的性能提升,MRR@10指标提升近20%。该方案无需人工标注,降低了成本,为RAG优化提供了新的思路。

💡 知识蒸馏是一种将大型教师模型的知识迁移到小型学生模型上的技术,本研究利用该技术优化RAG系统中的重排序模型,降低了优化成本。

⚙️ 核心流程包括:使用教师模型生成logit分数、构建训练数据、训练学生模型和性能评测。其中,使用vLLM加速推理过程,显著提高了效率。

📈 实验结果表明,经过知识蒸馏后,BGE-reranker-v2-m3模型在stackoverflowdupquestions-reranking测试集上的各项重排指标均有显著提升,MRR@10指标提升了19.96%。

🛠️ 整个流程被封装成脚本,用户只需按顺序执行即可完成模型优化。项目代码已开源在GitHub上,方便用户复现和应用。

🤔 本文的核心在于,通过MarginMSE损失函数,让学生模型学习区分“好”与“更好”的差异,而非完全复现教师模型的分数,从而保留了学生模型自身的特性。

BGE-reranker-v2-m3 是一个很好用的重排序模型,在RAG(检索增强生成)中用于进一步优化检索出的文档。但是也存在一个痛点:用大模型合成、甚至人工标注 (query, positive, negative) 三元组数据用于训练微调,过程麻烦且成本较高。

最近,阿里云发布了Qwen3-reranker系列SOTA重排序模型。本文将分享一个低成本的优化方案:利用系列中最强的 Qwen3-Reranker-8B(教师模型),将其知识蒸馏到 0.6B 的 BGE-reranker-v2-m3(学生模型)上。实验结果表明,通过该方法,学生模型在 stackoverflowdupquestions-reranking 数据集上的 MRR@10 指标提升幅度达到 19.96%。

项目代码已开源在 GitHub: github.com/kanhaoning/…

一、核心工具与方法

1.1 核心工具

1.2 训练方法:MarginMSE知识蒸馏

训练的目标是让学生模型学会模仿教师模型对不同样本打分的差异,而不是直接学习教师模型打出来的分数。具体来说,就是让学生模型对于(查询,更相关文档)和(查询,不相关文档)这两个组合的相关性分数之差,尽可能地接近教师模型给出的分数差。

使用的损失函数是 MarginMSE,公式如下:

L(Q,P+,P)=MSE((Ms(Q,P+)Ms(Q,P)),(Mt(Q,P+)Mt(Q,P)))L(Q, P_+, P_-) = \text{MSE}( (M_s(Q, P_+) - M_s(Q, P_-)), (M_t(Q, P_+) - M_t(Q, P_-)) )

其中:

这种方法不要求学生模型完全复现教师模型的分数,只要求它学会区分“好”与“更好”的差异,这使学生模型可以学习与自身打分差异较大的教师模型的同时而在保留自身打分的特性。相比于SFT,知识蒸馏不需要人工标注,尤其适合在有大量垂直领域的原始数据,但是没有高质量的标注的场景。

二、环境准备

首先,确保你已安装所有必要的库。

我在实验中使用的主要库版本如下:

Package                 Version----------------------- ------------------------torch                   2.6.0sentence-transformers   5.0.0transformers            4.53.1vllm                    0.8.4

如果尚未安装,可以使用 pip 命令安装。

pip install torch sentence-transformers==5.0.0 transformers==4.53.1 vllm==0.8.4# modelscope 用于方便地下载国内模型pip install modelscope

三、复现步骤

整个流程分为四步:生成教师分数 -> 构建训练数据 -> 训练学生模型 -> 性能评测。我已经将每一步都封装成了脚本,你只需要按顺序执行即可。

步骤 1:下载模型与数据集

首先,我们需要把教师模型、学生模型和数据集准备好。

1. 下载 BGE-reranker-v2-m3 (学生模型)

from modelscope import snapshot_downloadmodel_dir = snapshot_download('BAAI/bge-reranker-v2-m3', cache_dir='/path/to/your/models')

2. 下载 Qwen3-Reranker-8B (教师模型)

from modelscope import snapshot_downloadmodel_dir = snapshot_download('Qwen/Qwen3-Reranker-8B', cache_dir='/path/to/your/models')

3. 下载 Stack Overflow 数据集访问 MTEB/stackoverflowdupquestions-reranking,手动下载 train.jsonl.gztest.jsonl.gz,然后解压。

gunzip train.jsonl.gzgunzip test.jsonl.gz

将解压后的 train.jsonltest.jsonl 文件放到你的项目目录下。

步骤 2:生成教师模型 Logits 分数

这一步,我们的目标是用更强大的教师模型 Qwen3-Reranker-8B 为数据集中的每一个 (query, passage) 对计算一个logit分数。由于数据集很大(训练集有约60万个query-passage pair),直接用 transformers 跑会非常慢。为了大幅提高效率,我们采用 vLLM 框架进行推理加速。

执行脚本:

bash generate_logits.sh

这个脚本会调用 generate_logits.py。在运行前,请修改脚本内的 --model_path,使其指向你下载好的 Qwen3-Reranker-8B 模型路径。

实现原理

Qwen3-Reranker 将“重排序”任务转化为了一个“生成”任务。输入判断查询(Query)和文档(Document)是否相关的提示词后,预测下一个词是 "yes" 还是 "no",以此来判断文档与查询的相关性。

官方给出的相关性分数计算公式是基于Softmax的概率:

score(q,d)=elogit(yesI,q,d)elogit(yesI,q,d)+elogit(noI,q,d)\text{score}(q, d) = \frac{e^{logit(\text{yes}|I,q,d)}} {e^{logit(\text{yes}|I,q,d)} + e^{logit(\text{no}|I,q,d)}}

这个公式将模型的输出转换为一个 01 之间的概率值。

在知识蒸馏(knowledge distillation)的场景中,我们需要 Qwen3-Reranker 这种 decoder-only 架构模型提供类似于 cross-encoder 架构中的等效 logit 值。为此,需要对原始的概率得分进行反 sigmoid 变换,即:

logit=log(score1score)\text{logit} = \log\left(\frac{\text{score}}{1 - \text{score}}\right)

最终,用于知识蒸馏的 logit 值可以通过以下公式获得:

logit=logP(yesI,q,d)logP(noI,q,d)\text{logit} = \log P(\text{yes} \mid I, q, d) - \log P(\text{no} \mid I, q, d)

其中,P(yes) 和 P(no) 分别表示模型生成“yes”和“no”的条件概率。得益于 vLLM 提供的接口,我们可以高效地获取每个 token 的对数概率值。接下来我们将基于 vLLM 框架实现上述核心逻辑代码。

代码实现

1. 构造Qwen3-Reranker模型输入

我们首先需要按照 Qwen3-Reranker 指定的模板,将 (query, passage) 对格式化为一段提示词。

# 文件: generate_logits.pydef format_and_tokenize_inputs(    tokenizer: AutoTokenizer,    queries: List[str],    docs: List[str],    instruction: str,    max_length: int) -> List[TokensPrompt]:    """使用 apply_chat_template 格式化并 tokenize 输入"""    messages = []    for query, doc in zip(queries, docs):        # 这是模型要求的标准对话格式        message = [            {"role": "system", "content": "Judge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\"."},            {"role": "user", "content": f"<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}"}        ]        messages.append(message)        # 使用 tokenizer 的模板功能,高效地将文本转换为 token IDs    templated_messages = tokenizer.apply_chat_template(        messages,        tokenize=True,        add_generation_prompt=True, # add_generation_prompt=True 会自动添加 assistant 角色的起始符        enable_thinking=False    )        # 截断超长序列并转换为 vLLM 接受的 TokensPrompt 格式    processed_messages = [ele[:max_length] for ele in templated_messages]    final_messages = [TokensPrompt(prompt_token_ids=ele) for ele in processed_messages]    return final_messages

2. 计算 Logit 分数

接收 vLLM 的推理结果,并从中提取 "yes" 和 "no" 的对数概率(logprobs),最后计算它们的差值。

# 文件: generate_logits.pydef compute_scores_vllm(    model: LLM,    tokenizer: AutoTokenizer,    sampling_params: SamplingParams,    batch_queries: List[str],    batch_docs: List[str],    instruction: str,    max_length: int) -> List[float]:    """计算分数的函数"""    # 获取 'yes' 和 'no' 两个词对应的 token ID    true_token = tokenizer("yes", add_special_tokens=False).input_ids[0]    false_token = tokenizer("no", add_special_tokens=False).input_ids[0]    # 1. 格式化输入    tokenized_batch = format_and_tokenize_inputs(tokenizer, batch_queries, batch_docs, instruction, max_length)    # 2. 使用 vLLM 并行推理    outputs = model.generate(tokenized_batch, sampling_params=sampling_params, use_tqdm=False)    scores = []    for output in outputs:        # 3. 从推理结果中提取 logprobs        # 我们只关心生成的第一个 token,所以取 logprobs[-1]        final_logprobs = output.outputs[0].logprobs[-1]        # 4. 获取 'yes' 和 'no' 的对数概率,如果某个词不存在于 top logprobs 中,给一个很小的默认值        true_logprob = final_logprobs.get(true_token, -10.0)        if not isinstance(true_logprob, float): # vLLM 可能返回 Logprob 对象            true_logprob = true_logprob.logprob        false_logprob = final_logprobs.get(false_token, -10.0)        if not isinstance(false_logprob, float):            false_logprob = false_logprob.logprob        # 5. 核心:计算yes和no的对数概率差值作为Logit        logit_diff = true_logprob - false_logprob        scores.append(logit_diff)            return scores

脚本与产出

generate_logits.sh 脚本负责调用上述 Python 代码,并传入必要的参数,如模型路径、输入文件名和批处理大小。

# generate_logits.sh 内容#!/bin/bash# 使用 vLLM + Qwen3-Reranker-8B 生成训练/测试数据的 logits 分数python generate_logits.py \  --model_path your_path_to/Qwen3-Reranker-8B \  --input_files train.jsonl test.jsonl \  --output_suffix _distill_qwen3_8b_vLLMlogit \  --batch_size 8 \  --max_model_len 8192 \  --gpu_memory_utilization 0.9 \  --task_instruction "Given a web search query, retrieve relevant passages that answer the query"

脚本运行成功后,将生成两个新的jsonl文件,分别对应构建训练集、测试集所需的logit分数:

每一行是一个(query, passage, score)pair,以下是一个具体例子:

{"query": "String isNullOrEmpty in Java?", "passage": "Java equivalent of c# String.IsNullOrEmpty() and String.IsNullOrWhiteSpace()", "score": 0.875}

步骤 3:构建训练样本

接下来,我们需要将上一步生成的Logit分数文件,转换为 MarginMSE 损失函数需要的三元组格式 (query, positive, negative, score_diff)

执行脚本:

bash create_triplets.sh

该脚本会调用 create_triplets.py,它会为每个 query 下的高分 passage(正例)匹配若干个低分 passage(负例),并计算它们的分数差。这一步是由Gemini生成的采样方法,不一定是最优解。

产出:此步骤会生成最终的训练和评估文件:

以下是一条数据的具体例子:

{"query": "String isNullOrEmpty in Java?", "positive": "Java equivalent of c# String.IsNullOrEmpty() and String.IsNullOrWhiteSpace()", "negative": "isLocalHost(String hostNameOrIpAddress) in Java", "score": 6.231092929840088}

步骤 4:训练学生模型

现在我们开始训练(蒸馏)学生模型 bge-reranker-v2-m3

执行脚本:

bash train.sh

此脚本会调用 train.py,使用 sentence-transformers 框架提供的 MarginMSELoss 来进行微调。

关键参数说明:

训练过程日志会显示 eval_loss,我们可以依据此指标来保存最佳模型。

步骤 5:性能评测与对比

训练完成后,我们评测一下效果

执行脚本:

bash evaluate.sh

该脚本会调用 evaluation.py,分别评估蒸馏前蒸馏后的模型在测试集上的性能,并清晰地展示对比结果。

关键参数说明:

四、结果分析

经过蒸馏,bge-reranker-v2-m3 模型在 stackoverflowdupquestions-reranking 测试集上的各项重排指标都获得了明显提升

指标 (Metric)蒸馏前 (Before)蒸馏后 (After)绝对提升相对提升
MAP0.4720610.565317+0.093256+19.76% 🚀
MRR@100.4782340.573779+0.095545+19.98% 🚀
NDCG@100.5472840.639033+0.091748+16.76% 🚀

从上表可以看出所有核心评估指标(MAP, MRR, NDCG)均有16-20%的增长。这表明在这个场景有效将知识从大模型蒸馏到小模型上(但是整体分数还是较低,可能这个数据集比较有难度)

五、参考文献

Qwen3 Embedding: Advancing Text Embedding and Reranking Through Foundation Models

Improving Efficient Neural Ranking Models with Cross-Architecture Knowledge Distillation


如果这篇文章对你有帮助,请给我的Github项目点个star吧:github.com/kanhaoning/…


本文首发于知乎平台,原文链接:zhuanlan.zhihu.com/p/192822324…

Fish AI Reader

Fish AI Reader

AI辅助创作,多种专业模板,深度分析,高质量内容生成。从观点提取到深度思考,FishAI为您提供全方位的创作支持。新版本引入自定义参数,让您的创作更加个性化和精准。

FishAI

FishAI

鱼阅,AI 时代的下一个智能信息助手,助你摆脱信息焦虑

联系邮箱 441953276@qq.com

相关标签

RAG 知识蒸馏 重排序模型 BGE-reranker-v2-m3 Qwen3-Reranker-8B
相关文章