掘金 人工智能 2024年07月08日
基于Sentence Transformer微调向量模型
index_new5.html
../../../zaker_core/zaker_tpl_static/wap/tpl_guoji1.html

 

Sentence Transformer库升级到V3版本后,模型训练和微调变得更加简便。本文将以官方教程为例,详细介绍Sentence Transformer V3的模型微调过程,涵盖了数据准备、损失函数选择、训练参数配置、评估器使用和训练器构建等关键步骤。

🤔 **数据准备**: Sentence Transformer V3支持多种数据格式,如json、csv、parquet等。数据格式需要与选择的损失函数相匹配,例如,如果使用三元组损失函数,则数据集的格式需要包含['anchor', 'positive', 'negative']三列,且顺序不能颠倒。 常见的损失函数包括SoftmaxLoss、MultipleNegativesRankingLoss、ContrastiveLoss、CoSENTLoss、AnglELoss、CosineSimilarityLoss、TripletLoss等,它们适用于不同的数据格式和任务。 为了方便数据加载,可以使用`load_dataset`函数,例如: python from datasets import load_dataset csv_dataset = load_dataset("csv", data_files="my_file.csv") json_dataset = load_dataset("json", data_files="my_file.json") parquet_dataset = load_dataset("parquet", data_files="my_file.parquet")

🚀 **训练参数配置**: 训练参数可以用来优化模型训练效果,并显示训练过程中的进度和参数信息,便于调试。 影响训练效果的参数包括:`learning_rate`、`lr_scheduler_type`、`warmup_ratio`、`num_train_epochs`、`max_steps`、`per_device_train_batch_size`、`per_device_evak_batch_size`、`auto_find_batch_size`、`fp16`、`bf16`、`gradient_accumulation_steps`、`gradient_checkpointing`、`eval_accmulation_steps`、`optim`、`batch_sampler`、`multi_dataset_batch_sampler`等。 观察训练过程的参数包括:`eval_strategy`、`eval_steps`、`save_strategy`、`save_steps`、`save_total_limit`、`load_best_model_at_end`、`report_tolog_eval`、`log_eval`、`logging_steps`、`push_to_hub`、`hub_model_id`、`hub_strategy`、`hub_private_repo`等。

📊 **评估器选择**: 评估器用于评估模型训练过程中的损失。选择合适的评估器需要根据数据格式进行判断。 常见的评估器包括:`BinaryClassificationEvaluator`、`EmbeddingSimilarityEvaluator`、`InformationRetrievalEvaluator`、`MSEEvaluator`、`ParaphraseMiningEvaluator`、`RerankingEvaluator`、`TranslationEvaluator`、`TripletEvaluator`等,它们适用于不同的数据格式和任务。

🧰 **训练器构建**: Sentence Transformer V3提供了一个名为`SentenceTransformerTrainer`的训练器,它可以将模型、训练数据、损失函数、训练参数和评估器整合在一起进行模型训练。 以下是构建`SentenceTransformerTrainer`的示例代码: python from datasets import load_dataset from sentence_transformers import (SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments, SentenceTransformerModelCardData) from sentence_transformers.losses import MultipleNegativesRankingLoss from sentence_transformers.training_args import BatchSamplers from sentence_transformers.evaluation import TripletEvaluator # 1. 加载模型 model = SentenceTransformer("microsoft/mpnet-base", model_card_data=SentenceTransformerModelCardData(language="en", license="apache-2.0", model_name="MPNet base trained on AllNLI triplets")) # 2. 加载数据集 dataset = load_dataset("sentence-transformers/all-nli", "triplet") train_dataset = dataset["train"].select(range(100_000)) eval_dataset = dataset["dev"] test_dataset = dataset["test"] # 3. 定义损失函数 loss = MultipleNegativesRankingLoss(model) # 4. 配置训练参数 args = SentenceTransformerTrainingArguments( output_dir="models/mpnet-base-all-nli-triplet", num_train_epochs=1, per_device_train_batch_size=16, per_device_eval_batch_size=16, learning_rate=2e-5, warmup_ratio=0.1, fp16=True, # Set to False if you get an error that your GPU can't run on FP16 bf16=False, # Set to True if you have a GPU that supports BF16 batch_sampler=BatchSamplers.NO_DUPLICATES, # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch eval_strategy="steps", eval_steps=100, save_strategy="steps", save_steps=100, save_total_limit=2, logging_steps=100, run_name="mpnet-base-all-nli-triplet" ) # 5. 创建评估器 dev_evaluator = TripletEvaluator(anchors=eval_dataset["anchor"], positives=eval_dataset["positive"], negatives=eval_dataset["negative"], name="all-nli-dev") # 6. 创建训练器 trainer = SentenceTransformerTrainer(model=model, args=args, train_dataset=train_dataset, eval_dataset=eval_dataset, loss=loss, evaluator=dev_evaluator) # 7. 开始训练 trainer.train() # 8. 评估训练后的模型 test_evaluator = TripletEvaluator(anchors=test_dataset["anchor"], positives=test_dataset["positive"], negatives=test_dataset["negative"], name="all-nli-test") test_evaluator(model) # 9. 保存训练后的模型 model.save_pretrained("models/mpnet-base-all-nli-triplet/final")

Sentence Transformer库升级到了V3,其中对模型训练部分做了优化,使得模型训练和微调更加简单了,跟着官方教程走了一遍,顺利完成向量模型的微调,以下是对官方教程的精炼和总结。

一 所需组件

使用Sentence Transformer库进行向量模型的微调需要如下的组件:

    数据数据: 用于训练和评估的数据。损失函数 : 一个量化模型性能并指导优化过程的函数。训练参数 (可选): 影响训练性能和跟踪/调试的参数。评估器 (可选): 一个在训练前、中或后评估模型的工具。训练器 : 将模型、数据集、损失函数和其他组件整合在一起进行训练。

二 数据集

大部分微调用到的数据都是本地的数据集,因此这里只提供本地数据的处理方法。如用其他在线数据可参考相对应的API。

1 数据类型

常见的数据类型为json、csv、parquet,可以使用load_dataset进行加载:

from datasets import load_datasetcsv_dataset = load_dataset("csv", data_files="my_file.csv")json_dataset = load_dataset("json", data_files="my_file.json")parquet_dataset = load_dataset("parquet", data_files="my_file.parquet")

2 数据格式

数据格式需要与损失函数相匹配。如果损失函数需要计算三元组,则数据集的格式为['anchor', 'positive', 'negative'],且顺序不能颠倒。如果损失函数计算的是句子对的相似度或者标签类别,则数据集中需要包含['label']或者['score'],其余列都会作为损失函数的输入。常见的数据格式和损失函数选择见表1。

三 损失函数

链接整理了一些常见的数据格式和匹配的损失函数

InputsLabelsAppropriate Loss Functions
(sentence_A, sentence_B) pairsclassSoftmaxLoss
(anchor, positive) pairsnoneMultipleNegativesRankingLoss
(anchor, positive/negative) pairs1 if positive, 0 if negativeContrastiveLoss / OnlineContrastiveLoss
(sentence_A, sentence_B) pairsfloat similarity scoreCoSENTLoss / AnglELoss / CosineSimilarityLoss
(anchor, positive, negative) tripletsnoneMultipleNegativesRankingLoss / TripletLoss

表1 常见的数据格式和损失函数

四 训练参数

配置训练参数主要是用于提升模型的训练效果,同时可以显示训练过程的进度或者其他参数信息,方便调试。

1 影响训练效果的参数

learning_ratelr_scheduler_typewarmup_rationum_train_epochs
max_stepsper_device_train_batch_sizeper_device_evak_batch_sizeauto_find_batch_size
fp16bf16gradient_accumulation_stepsgradient_checkpointing
eval_accmulation_stepsoptimbatch_samplermulti_dataset_batch_sampler

2 观察训练过程的参数

eval_strategyeval_stepssave_strategysave_steps
save_total_limitload_best_model_at_endreport_tolog_evallog_eval
logging_stepspush_to_hubhub_model_idhub_strategy
hub_private_repo

五 评估器

评估器用于评估模型训练过程中的损失。同损失函数的选择一样,它也需要与数据格式相匹配,以下是评估器的选择依据。

EvaluatorRequired Data
BinaryClassificationEvaluatorPairs with class labels
EmbeddingSimilarityEvaluatorPairs with similarity scores
InformationRetrievalEvaluatorQueries (qid => question), Corpus (cid => document), and relevant documents (qid => set[cid])
MSEEvaluatorSource sentences to embed with a teacher model and target sentences to embed with the student model. Can be the same texts.
ParaphraseMiningEvaluatorMapping of IDs to sentences & pairs with IDs of duplicate sentences.
RerankingEvaluatorList of {'query': '...', 'positive': [...], 'negative': [...]} dictionaries.
TranslationEvaluatorPairs of sentences in two separate languages.
TripletEvaluator(anchor, positive, negative) pairs.

六 训练器

训练器的作用是把先前的组件组合在一起使用。我们仅需要指定模型、训练数据、损失函数、训练参数(可选)、评估器(可选),就可以开始模型的训练。

from datasets import load_datasetfrom sentence_transformers import (SentenceTransformer,SentenceTransformerTrainer,SentenceTransformerTrainingArguments,SentenceTransformerModelCardData,)from sentence_transformers.losses import MultipleNegativesRankingLossfrom sentence_transformers.training_args import BatchSamplersfrom sentence_transformers.evaluation import TripletEvaluator# 1. Load a model to finetune with 2. (Optional) model card datamodel = SentenceTransformer(    "microsoft/mpnet-base",    model_card_data=SentenceTransformerModelCardData(        language="en",        license="apache-2.0",        model_name="MPNet base trained on AllNLI triplets",    ))# 3. Load a dataset to finetune ondataset = load_dataset("sentence-transformers/all-nli", "triplet")train_dataset = dataset["train"].select(range(100_000))eval_dataset = dataset["dev"]test_dataset = dataset["test"]# 4. Define a loss functionloss = MultipleNegativesRankingLoss(model)# 5. (Optional) Specify training argumentsargs = SentenceTransformerTrainingArguments(    # Required parameter:    output_dir="models/mpnet-base-all-nli-triplet",    # Optional training parameters:    num_train_epochs=1,    per_device_train_batch_size=16,    per_device_eval_batch_size=16,    learning_rate=2e-5,    warmup_ratio=0.1,    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16    bf16=False,  # Set to True if you have a GPU that supports BF16    batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch    # Optional tracking/debugging parameters:    eval_strategy="steps",    eval_steps=100,    save_strategy="steps",    save_steps=100,    save_total_limit=2,    logging_steps=100,    run_name="mpnet-base-all-nli-triplet",  # Will be used in W&B if `wandb` is installed)# 6. (Optional) Create an evaluator & evaluate the base modeldev_evaluator = TripletEvaluator(    anchors=eval_dataset["anchor"],    positives=eval_dataset["positive"],    negatives=eval_dataset["negative"],    name="all-nli-dev",)dev_evaluator(model)# 7. Create a trainer & traintrainer = SentenceTransformerTrainer(    model=model,    args=args,    train_dataset=train_dataset,    eval_dataset=eval_dataset,    loss=loss,    evaluator=dev_evaluator,)trainer.train()# (Optional) Evaluate the trained model on the test settest_evaluator = TripletEvaluator(    anchors=test_dataset["anchor"],    positives=test_dataset["positive"],    negatives=test_dataset["negative"],    name="all-nli-test",)test_evaluator(model)# 8. Save the trained modelmodel.save_pretrained("models/mpnet-base-all-nli-triplet/final")

Fish AI Reader

Fish AI Reader

AI辅助创作,多种专业模板,深度分析,高质量内容生成。从观点提取到深度思考,FishAI为您提供全方位的创作支持。新版本引入自定义参数,让您的创作更加个性化和精准。

FishAI

FishAI

鱼阅,AI 时代的下一个智能信息助手,助你摆脱信息焦虑

联系邮箱 441953276@qq.com

相关标签

Sentence Transformer 模型微调 V3 自然语言处理 深度学习
相关文章