掘金 人工智能 07月31日 11:43
用KL散度将Qwen3-8B向量模型知识蒸馏给小模型BGE-m3
index_new5.html
../../../zaker_core/zaker_tpl_static/wap/tpl_guoji1.html

 

本文详细介绍了如何利用强大的Qwen3-Embedding-8B作为教师模型,通过知识蒸馏技术提升BGE-m3向量模型的性能。文章阐述了基于KL散度的知识蒸馏方法,通过生成软标签来传递更精细的监督信号,从而缓解微调时因数据质量不高导致的过拟合和灾难性遗忘问题。实验结果显示,该方法在领域内数据集上显著提升了MAP@10指标10.20%,同时领域外性能仅有小幅下降,有效保留了模型的通用能力。文章提供了详细的环境准备、复现步骤及代码实现思路,并给出了详细的性能评测结果。

💡 **知识蒸馏的必要性**:在向量模型微调中,高质量数据集的构建是难点。当数据质量有限时,使用更强大的教师模型生成软标签进行知识蒸馏,能提供更精细的监督信号,有效缓解过拟合和灾难性遗忘等问题。

🚀 **核心蒸馏方法**:本文采用基于KL散度的知识蒸馏(DistillKLDivLoss),使学生模型学习教师模型的完整概率分布,以此传递更精细的“排序偏好”。损失函数为教师模型和学生模型计算的概率分布之间的KL散度,并使用温度系数τ来平滑分布。

🛠️ **实践复现流程**:整个流程包括数据准备(下载模型和数据集)、生成蒸馏数据集(使用vLLM加速)、训练学生模型(利用sentence-transformers库)和性能评测(对比蒸馏前后模型在领域内外的表现)。

📈 **实验效果显著**:通过Qwen3-Embedding-8B蒸馏到BGE-m3,在scidocs数据集上MAP@10指标提升了10.20%,领域外性能下降幅度低于2.5%,表明该方法在提升特定领域能力的同时,有效保持了模型的通用性。

⚖️ **软标签的优势**:相比硬标签的损失函数,软标签的KL散度蒸馏在保留通用能力和提升泛化能力方面展现出一定优势,但具体方法选择需结合数据量、质量及教师模型性能综合考虑。

BGE-m3 是开源社区下载量最高的向量模型之一,在RAG(检索增强生成)中用于检索出文档。通常微调的一个难点是需要构造质量比较高的数据集,如果合成数据或者人工标注不够准确,过于绝对的监督信号会放大这种不准确,造成过拟合、灾难性遗忘。而用一个更强大的教师模型生成软标签进行知识蒸馏,可以提供更加精细的监督信号来缓解这种问题,值得当你的数据质量有限时可以尝试一下。

近期,阿里云发布了Qwen3-Embedding系列SOTA向量模型。本文利用系列中最强的 Qwen3-Embedding-8B(教师模型),将其知识蒸馏到 0.6B 的 BGE-m3(学生模型)上。实验结果表明,通过该方法,学生模型在 scidocs-reranking 数据集上的 MAP@10 指标提升幅度达到 10.20%,领域外下降幅度低于2.5%。

项目代码已开源在 GitHub: github.com/kanhaoning/…

一、核心工具与方法

1.1 核心工具

1.2 训练方法:基于KL散度的知识蒸馏

在本次实践采用 sentence-transformers 库中更适合排序蒸馏任务的 DistillKLDivLoss

这种方法的核心思想是:学生模型学习通过教师模型所计算出的查询与文档相似度的完整概率分布。这能更精细地传递教师模型的“排序偏好”。

其损失函数是教师模型概率分布 P_t 与学生模型概率分布 P_s 之间的 KL散度 (Kullback-Leibler Divergence)

具体步骤如下:

    对于一个查询 Q 和一组文档 {P_1, P_2, ..., P_n}(包含一个正样本和多个负样本),教师模型和学生模型分别计算它们的相似度分数。使用带有温度系数 τSoftmax 函数,将这些分数转换成概率分布。温度 τ 可以平滑概率分布,让教师模型不那么“绝对”,从而为学生模型提供更丰富的学习信号。计算两个概率分布之间的KL散度作为损失。这里乘一个τ2\tau^2是为了当Softmax于平滑时补偿梯度,使得梯度保持稳定。

公式如下:

L(Q,{Pi})=τ2DKL(PtPs)=τ2iPt(Q,Pi)log(Pt(Q,Pi)Ps(Q,Pi))L(Q, \{P_i\}) = \tau^2 \cdot D_{KL}(P_t || P_s) = \tau^2 \cdot \sum_{i} P_t(Q, P_i) \cdot \log\left(\frac{P_t(Q, P_i)}{P_s(Q, P_i)}\right)

其中:

源码中为了避免数值下溢出,Ps(Q,Pi)P_s(Q, P_i)是使用更稳定的log_softmax实现的。这里各个符号的含义如下:

二、环境准备

首先,确保你已安装所有必要的库。

我在实验中使用的主要库版本如下:

Package                 Version----------------------- ------------------------torch                   2.6.0sentence-transformers   5.0.0transformers            4.53.1vllm                    0.8.4

如果尚未安装,可以使用 pip 命令安装。

pip install torch sentence-transformers==5.0.0 transformers==4.53.1 vllm==0.8.4# modelscope 用于方便地下载国内模型pip install modelscope

三、复现步骤

整个流程分为四步:生成教师分数 -> 构建训练数据 -> 训练学生模型 -> 性能评测。我已经将每一步都封装成了脚本,你只需要按顺序执行即可。

步骤 1:下载模型与数据集

首先,我们需要把教师模型、学生模型和数据集准备好。

1. 下载 BGE-m3 (学生模型)

from modelscope import snapshot_downloadmodel_dir = snapshot_download('BAAI/bge-m3', cache_dir='/path/to/your/models')

2. 下载 Qwen3-Embedding-8B (教师模型)

from modelscope import snapshot_downloadmodel_dir = snapshot_download('Qwen/Qwen3-Embedding-8B', cache_dir='/path/to/your/models')

3. 下载 Scidocs 数据集

注意这里直接拉取数据集可能会报错,建议手动下载:

访问 MTEB/scidocs-reranking,手动下载validation.jsonl.gztest.jsonl.gz(均为3.5MB)到路径Embedding-Distillation/dataset_scidocs,然后执行以下代码解压:

gunzip validation.jsonl.gzgunzip test.jsonl.gz

解压后Embedding-Distillation/dataset_scidocs目录下应该有 validation.jsonltest.jsonl 两个文件。为了降低复现成本,将使用这里的validation.jsonl为训练集,test.json为领域内测试集。

3. 下载 Stackoverflowdupquestions 数据集

访问 MTEB/stackoverflowdupquestions-reranking,手动下载test.jsonl.gz(1.35MB)到路径Embedding-Distillation/dataset_stackoverflowdupquestions,然后执行以下代码解压:

gunzip test.jsonl.gz

解压后Embedding-Distillation/dataset_stackoverflowdupquestions目录下应该有test.jsonl 文件。为了测试灾难性遗忘,将使用这里的test.jsonl为领域外测试集。

步骤 2:生成蒸馏数据集

这一步,我们的目标是用更强大的教师模型 Qwen3-Embedding-8B 为数据集中的每一个 (query, passage) 对计算一个相似度分数。由于数据集很大(训练集有约8.8万个passage),直接用 transformers 跑会非常慢。为了大幅提高速度,我们采用 vLLM 框架进行推理加速。

执行脚本:

bash generate_distillation_data.sh

这个脚本会调用 generate_distillation_data.py。在运行前,请修改脚本内的 --teacher_model_path,使其指向你下载好的 Qwen3-Embedding-8B 模型路径。

实现思路

generate_distillation_data.py 的逻辑清晰高效,我们可以将其拆解为以下四步:

1. 读取与展开数据脚本首先读取原始的 validation.jsonl 文件。文件中的每一行都包含一个 querypositive 列表和 negative 列表。为了给 DistillKLDivLoss 准备输入,脚本会使用 itertools.product (笛卡尔积)将positive和negative文档两两组合获取大量的 (query, positive_doc, negative_doc) 三元组。

# 关键代码片段 1: 使用笛卡尔积生成三元组from itertools import productfor pos_item, neg_item in product(positives, negatives):    unique_texts.add(pos_item)    unique_texts.add(neg_item)    triplets.append({'query': query, 'positive': pos_item, 'negative': neg_item})

2. 批量向量化为了最大化效率,脚本会收集所有不重复的文本(包括所有 query、positive 和 negative),然后用 vLLMmodel.embed() 方法一次性将它们全部向量化。另外需要注意的是,Qwen3-Embedding官方代码会给Query加上一句任务的Instruct,比如Given a web search query, retrieve relevant passages that answer the query,但是实测会降低蒸馏效果,所以没有使用。

# 关键代码片段 2: 使用 vLLM 进行高效批量编码input_texts = list(unique_texts)outputs = model.embed(input_texts)all_embeddings = [torch.tensor(o.outputs.embedding) for o in outputs]

3. 计算教师分数获得所有文本的向量后,脚本会为每一个三元组计算教师模型给出的余弦相似度分数。通常对于标准化的向量,余弦相似度可以使用更简单的内积等效实现。

# 关键代码片段 3: 计算(query, pos)和(query, neg)的相似度emb_q = text_to_embedding.get(q_text)emb_p = text_to_embedding.get(p_text)emb_n = text_to_embedding.get(n_text)sim_pos = similarity(emb_q, emb_p)sim_neg = similarity(emb_q, emb_n)

4. 生成蒸馏文件最后,脚本将每个三元组和它对应的两个相似度分数打包成一个新的 JSON 对象,写入到输出文件中。这个格式正是 DistillKLDivLoss 所需要的。

# 关键代码片段 4: 构建最终的输出记录record = {    "query": q_text,    "positive": p_text,    "negative": n_text,    "label": [sim_pos, sim_neg]  # "软标签"}f_out.write(json.dumps(record, ensure_ascii=False) + '\n')

注意这里一个样本只有(query, positive, negative)三元组,将公式中的候选文档集合 {P_i} 简化为了 {positive, negative} 两个文档。教师模型生成的 label[M_t(Q, P_positive), M_t(Q, P_negative)]。实践时不用拘泥于只用单个negative的这种格式,sentence-transformers实现的DistillKLDivLoss也是支持一个样本有多个negative字段的,比如{"query": "q_text", "positive": "p_text", "negative1": "n_text1", "negative2": "n_text2", "label": [0.6, 0.3, 0.1]},但是这样会增大调参的难度。

最终会在Embedding-Distillation/dataset_scidocs路径下生成输入模型训练的jsonl文件validation_kldiv_distill.jsonl,这是一个样本例子:

{"query": "Beauty eMakeup: A Deep Makeup Transfer System", "positive": "Learning Hierarchical Features for Scene Labeling", "negative": "Registration with the Point Cloud Library: A Modular Framework for Aligning in 3-D", "label": [0.6058026552200317, 0.5828931331634521]}

步骤 3:训练学生模型

我们现在开始训练学生模型 BAAI/bge-m3

执行脚本:

bash train.sh

这个命令会执行 train.py 脚本。在运行前,请务必检查并修改 train.sh 中的几个关键路径和参数:

该脚本使用 sentence-transformers 库,其核心逻辑主要有以下几步:

实现思路

1. 加载学生模型和数据集脚本首先加载预训练的学生模型 bge-m3。这是一个 SentenceTransformer(双编码器)模型,它能将文本高效地映射到向量空间。接着,它加载我们在上一步生成的 .jsonl 蒸馏数据集。

# 文件: train.py# 1. 加载学生模型student_model = SentenceTransformer(model_args.student_model_name_or_path)# 2. 加载蒸馏数据集# 每条样本包含 query, positive, negative 和 label 字段train_dataset = load_dataset("json", data_files=model_args.train_dataset_path)["train"]

2. 定义 DistillKLDivLoss 损失函数该损失函数执行的一些细节如下:

使用代码如下:

# 文件: train.pyfrom sentence_transformers import losses# 定义KL散度蒸馏损失train_loss = losses.DistillKLDivLoss(    model=student_model,    temperature=model_args.temperature # 温度参数,默认为2.0)

3. 初始化并运行 SentenceTransformerTrainer最后,我们将所有组件,包括学生模型、训练参数、蒸馏数据集和损失函数都传递给 SentenceTransformerTrainer。这个trainer都实现了各种训练细节,如批处理、梯度累积、学习率调度、日志记录和模型保存。

# 文件: train.pyfrom sentence_transformers.trainer import SentenceTransformerTrainertrainer = SentenceTransformerTrainer(    model=student_model,    args=training_args,    train_dataset=train_dataset,    loss=train_loss,)# 启动训练trainer.train()

关键参数说明:

训练开始后,你可以通过日志实时监控训练进度。训练完成后,最终模型和检查点将保存在 --output_dir 指定的目录中,用于下一步性能评测。

步骤 4:性能评测

训练完成后,让我们来验证以下两个问题:

    学生模型在领域内(scidocs)上的性能提升了多少?学生模型在领域领域外,比如(stackoverflowdupquestions)有没有大幅下降?

现在使用蒸馏前后的 bge-m3 模型,分别对领域内和领域外两个测试集用sentence-transformers实现的RerankingEvaluator进快速进行评测。

执行脚本:

bash evaluation.sh

这个命令会调用 evaluation.py 脚本来执行完整的评测流程。在运行前,请务必修改 evaluation.sh 中的模型路径:

结果

等待 evaluation.py 脚本运行完毕后,生成如下表格:

指标蒸馏前蒸馏后绝对变化相对变化(%)
领域内 (scidocs)
map0.77440.8534+0.0790+10.20
mrr@100.93210.9554+0.0233+2.50
ndcg@100.82960.8973+0.0676+8.15
领域外 (stackoverflow)
map0.51680.5040-0.0129-2.49
mrr@100.52400.5116-0.0124-2.37
ndcg@100.59040.5774-0.0129-2.19

(注:你的具体数值可能会因训练的随机性、超参数设置等有微小差异,但总体趋势应保持一致。)

从结果中我们可以得出两个结论:

    领域内性能显著提升:在目标任务 scidocs 数据集上,3个关键指标都获得了明显增长。其中 MAP 指标提升了 10.20%,这证明了知识蒸馏的有效性。学生模型成功地从教师模型 Qwen3-Embedding-8B 那里学到了更精细的排序知识,使其在专业领域的文档排序能力变得更强。

    通用能力基本保持:在领域外的 stackoverflowdupquestions 数据集上,模型的性能只有一定的下降(MAP指标下降 2.49%),但是不算灾难性遗忘。学生模型在提升特定领域能力的同时,基本保留了其原有的通用能力。

这个方法整体实测下来,用软标签的损失函数DistillKLDivLoss整体上比硬标签的MultipleNegativesRankingLoss、TripletLoss在保留通用能力和提升泛化能力上都有一定优势,但是也有翻车的时候,比如教师模型表现不佳的数据集。具体实践得结合数据数量、质量、教师模型的性能考虑方法选用。

五、参考文献

Qwen3 Embedding: Advancing Text Embedding and Reranking Through Foundation Models

Distilling Dense Representations for Ranking using Tightly-Coupled Teachers


如果这篇文章对你有帮助,请给我的Github项目点个star吧:github.com/kanhaoning/…


本文首发于知乎平台,原文链接:zhuanlan.zhihu.com/p/193369385…

Fish AI Reader

Fish AI Reader

AI辅助创作,多种专业模板,深度分析,高质量内容生成。从观点提取到深度思考,FishAI为您提供全方位的创作支持。新版本引入自定义参数,让您的创作更加个性化和精准。

FishAI

FishAI

鱼阅,AI 时代的下一个智能信息助手,助你摆脱信息焦虑

联系邮箱 441953276@qq.com

相关标签

知识蒸馏 向量模型 BGE-m3 Qwen3-Embedding RAG
相关文章