So far we have finished building the overall architecture of our LLM, so it is time to put it to work on some downstream tasks.
The LLM we built is GPT-2, and its pretrained weights have been open-sourced. If all you need is text continuation, you can simply load the pretrained model and run inference.
In practice, however, an LLM usually needs to be fine-tuned on domain data to adapt it to a specific downstream task, such as spam classification, dialogue generation, or sentiment analysis.
This article uses a spam classification task to show how to fine-tune the pretrained GPT-2 on a message dataset. Our goal is a spam classifier: given the text of a message, the model predicts whether it is spam.
1. Preparing the spam classification dataset
Download the spam classification dataset from https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip. The data looks like this: each line is one sample, where the first column is the label and the text after the tab separator is the message itself. The label takes two values, ham and spam, indicating whether the message is spam.
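If the data has not been downloaded yet, a minimal sketch along the following lines can fetch and unpack it (the exact file name inside the archive, SMSSpamCollection, and the renamed .tsv path are assumptions chosen to match the path used in the next snippet):

```python
import os
import urllib.request
import zipfile

url = "https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip"
zip_path = "sms_spam_collection.zip"
extract_dir = "sms_spam_collection"

if not os.path.exists(zip_path):
    urllib.request.urlretrieve(url, zip_path)   # download the archive
with zipfile.ZipFile(zip_path, "r") as zf:
    zf.extractall(extract_dir)                  # unpack into sms_spam_collection/

# The extracted file has no extension; give it a .tsv suffix to match the path used below
src = os.path.join(extract_dir, "SMSSpamCollection")
dst = os.path.join(extract_dir, "SMSSpamCollection.tsv")
if os.path.exists(src) and not os.path.exists(dst):
    os.rename(src, dst)
```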
To make the subsequent preprocessing easier, read it into a pandas DataFrame:
```python
import pandas as pd

data_file_path = "sms_spam_collection/SMSSpamCollection.tsv"
df = pd.read_csv(data_file_path, sep="\t", header=None, names=["Label", "Text"])
```
A quick look at the label counts shows that the two classes are imbalanced (far more ham than spam). Here we undersample the majority class so that both classes end up with the same number of samples:
```python
def create_balanced_dataset(df):
    # Count the instances of "spam"
    num_spam = df[df["Label"] == "spam"].shape[0]
    # Randomly sample "ham" instances to match the number of "spam" instances
    ham_subset = df[df["Label"] == "ham"].sample(num_spam, random_state=123)
    # Combine the "ham" subset with all "spam" instances
    balanced_df = pd.concat([ham_subset, df[df["Label"] == "spam"]])
    return balanced_df

balanced_df = create_balanced_dataset(df)
print(balanced_df["Label"].value_counts())
# ham: 747, spam: 747
```
Encode the labels numerically so the model can consume them:
balanced_df["Label"] = balanced_df["Label"].map({"ham": 0, "spam": 1})
Now split the data into training, validation, and test sets, and save each to its own file:
```python
def random_split(df, train_frac, validation_frac):
    # Shuffle the entire DataFrame
    df = df.sample(frac=1, random_state=123).reset_index(drop=True)

    # Calculate split indices
    train_end = int(len(df) * train_frac)
    validation_end = train_end + int(len(df) * validation_frac)

    # Split the DataFrame
    train_df = df[:train_end]
    validation_df = df[train_end:validation_end]
    test_df = df[validation_end:]

    return train_df, validation_df, test_df

train_df, validation_df, test_df = random_split(balanced_df, 0.7, 0.1)

train_df.to_csv("train.csv", index=None)
validation_df.to_csv("validation.csv", index=None)
test_df.to_csv("test.csv", index=None)
```
The data is now preprocessed; the next step is to build the data loaders.
Because the messages have different lengths, we need to bring them to a uniform length. One option is to truncate every sample to the length of the shortest message, but that throws away information from longer messages. The other option is to pad every message to the length of the longest one; we take the second approach.
To avoid introducing extra bias, we pad with the end-of-text token, which corresponds to token ID 50256 in the GPT-2 vocabulary.
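As a quick check, and assuming the same tiktoken GPT-2 tokenizer used in the earlier articles, you can verify that the padding token "<|endoftext|>" indeed maps to ID 50256 (this also defines the tokenizer object used below):

```python
import tiktoken

tokenizer = tiktoken.get_encoding("gpt2")
# "<|endoftext|>" is a special token, so it must be explicitly allowed when encoding
print(tokenizer.encode("<|endoftext|>", allowed_special={"<|endoftext|>"}))
# [50256]
```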
```python
import torch
from torch.utils.data import Dataset

class SpamDataset(Dataset):
    def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256):
        self.data = pd.read_csv(csv_file)

        # Pre-tokenize texts
        self.encoded_texts = [
            tokenizer.encode(text) for text in self.data["Text"]
        ]

        if max_length is None:
            self.max_length = self._longest_encoded_length()
        else:
            self.max_length = max_length
            # Truncate sequences if they are longer than max_length
            self.encoded_texts = [
                encoded_text[:self.max_length]
                for encoded_text in self.encoded_texts
            ]

        # Pad sequences to the longest sequence
        self.encoded_texts = [
            encoded_text + [pad_token_id] * (self.max_length - len(encoded_text))
            for encoded_text in self.encoded_texts
        ]

    def __getitem__(self, index):
        encoded = self.encoded_texts[index]
        label = self.data.iloc[index]["Label"]
        return (
            torch.tensor(encoded, dtype=torch.long),
            torch.tensor(label, dtype=torch.long)
        )

    def __len__(self):
        return len(self.data)

    def _longest_encoded_length(self):
        max_length = 0
        for encoded_text in self.encoded_texts:
            encoded_length = len(encoded_text)
            if encoded_length > max_length:
                max_length = encoded_length
        return max_length
```
Use the SpamDataset class defined above to build the datasets and data loaders:
```python
train_dataset = SpamDataset(
    csv_file="train.csv",
    max_length=None,
    tokenizer=tokenizer
)
print(train_dataset.max_length)
# 120

val_dataset = SpamDataset(
    csv_file="validation.csv",
    max_length=train_dataset.max_length,
    tokenizer=tokenizer
)
test_dataset = SpamDataset(
    csv_file="test.csv",
    max_length=train_dataset.max_length,
    tokenizer=tokenizer
)
print(test_dataset.max_length)
# 120

from torch.utils.data import DataLoader

num_workers = 0
batch_size = 8

torch.manual_seed(123)

train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    drop_last=True,
)
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    drop_last=False,
)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    drop_last=False,
)
```
Check the shape of the batches:
print("Train loader:")for input_batch, target_batch in train_loader: print("Input batch dimensions:", input_batch.shape)# [8,120] print("Label batch dimensions", target_batch.shape)# [8] break
Everything looks good; the data is ready.
2. Adapting the GPT-2 architecture for spam classification
First, instantiate the GPT-2 architecture we built previously:
```python
BASE_CONFIG = {
    "vocab_size": 50257,     # Vocabulary size
    "context_length": 1024,  # Context length
    "drop_rate": 0.0,        # Dropout rate
    "qkv_bias": True         # Query-key-value bias
}

CHOOSE_MODEL = "gpt2-small (124M)"

model_configs = {
    "gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
    "gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
    "gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
    "gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
}

BASE_CONFIG.update(model_configs[CHOOSE_MODEL])
model_size = CHOOSE_MODEL.split(" ")[-1].lstrip("(").rstrip(")")

model = GPTModel(BASE_CONFIG)
```
Then load the official pretrained weights:
```python
model.load_state_dict(torch.load('./load_pretrained_GPT2/gpt2.pth'))
```
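Optionally, before modifying anything, you can sanity-check that the weights loaded correctly by greedily generating a short continuation. This is only a rough sketch; it assumes the GPTModel from the previous articles takes a [batch, num_tokens] tensor of token IDs and returns [batch, num_tokens, vocab_size] logits:

```python
model.eval()

prompt = "Every effort moves you"
idx = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)  # shape: [1, num_tokens]

with torch.no_grad():
    for _ in range(10):  # greedily generate 10 new tokens
        idx_cond = idx[:, -BASE_CONFIG["context_length"]:]   # crop to the supported context
        logits = model(idx_cond)[:, -1, :]                   # logits for the last position
        next_id = torch.argmax(logits, dim=-1, keepdim=True)
        idx = torch.cat([idx, next_id], dim=1)

print(tokenizer.decode(idx.squeeze(0).tolist()))
```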
Our goal is to feed the message text into GPT-2 and have it output whether the message is spam. The current output head, however, produces a continuation of the input, so we need to replace it. The change is simple: reduce the output dimension from vocab_size to 2:
```python
num_classes = 2
model.out_head = torch.nn.Linear(
    in_features=BASE_CONFIG["emb_dim"],
    out_features=num_classes
)
```
With that, GPT-2 has been turned into an architecture suited for binary classification.
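Because of causal attention, the last token is the only position that has attended to the whole message, which is why the utility functions below read off `[:, -1, :]`. As a quick sanity check (a sketch under the same interface assumptions as above), the modified model now outputs a 2-dimensional logit vector per token:

```python
inputs = torch.tensor(tokenizer.encode("Do you have time")).unsqueeze(0)
with torch.no_grad():
    outputs = model(inputs)

print(outputs.shape)      # e.g. torch.Size([1, 4, 2]): one 2-class logit vector per input token
print(outputs[:, -1, :])  # logits of the last token — these are what we use for classification
```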
3. Utility functions (loss and metrics)
Computing accuracy:
```python
def calc_accuracy_loader(data_loader, model, device, num_batches=None):
    model.eval()
    correct_predictions, num_examples = 0, 0

    if num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches = min(num_batches, len(data_loader))

    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            input_batch, target_batch = input_batch.to(device), target_batch.to(device)

            with torch.no_grad():
                logits = model(input_batch)[:, -1, :]  # Logits of last output token
            predicted_labels = torch.argmax(logits, dim=-1)

            num_examples += predicted_labels.shape[0]
            correct_predictions += (predicted_labels == target_batch).sum().item()
        else:
            break
    return correct_predictions / num_examples
```
Computing loss:
```python
def calc_loss_batch(input_batch, target_batch, model, device):
    input_batch, target_batch = input_batch.to(device), target_batch.to(device)
    logits = model(input_batch)[:, -1, :]  # Logits of last output token
    loss = torch.nn.functional.cross_entropy(logits, target_batch)
    return loss

def calc_loss_loader(data_loader, model, device, num_batches=None):
    total_loss = 0.
    if len(data_loader) == 0:
        return float("nan")
    elif num_batches is None:
        num_batches = len(data_loader)
    else:
        # Reduce the number of batches to match the total number of batches in the data loader
        # if num_batches exceeds the number of batches in the data loader
        num_batches = min(num_batches, len(data_loader))
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            total_loss += loss.item()
        else:
            break
    return total_loss / num_batches
```
4. Fine-tuning
We fine-tune only a subset of the layers, so most of the model is frozen first. Since the new classification head was created before this step, we also make sure it remains trainable:
```python
for param in model.parameters():
    param.requires_grad = False

# Unfreeze the last transformer block and the final LayerNorm
for param in model.trf_blocks[-1].parameters():
    param.requires_grad = True
for param in model.final_norm.parameters():
    param.requires_grad = True

# The classification head was replaced before the freeze above,
# so explicitly keep it trainable as well
for param in model.out_head.parameters():
    param.requires_grad = True
```
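To confirm that only the intended parameters will be updated, a quick optional check of the trainable parameter count:

```python
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")  # last block + final norm + classification head
```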
The training function:
```python
# Overall the same as `train_model_simple` in chapter 5
def train_classifier_simple(model, train_loader, val_loader, optimizer, device,
                            num_epochs, eval_freq, eval_iter):
    # Initialize lists to track losses and examples seen
    train_losses, val_losses, train_accs, val_accs = [], [], [], []
    examples_seen, global_step = 0, -1

    # Main training loop
    for epoch in range(num_epochs):
        model.train()  # Set model to training mode

        for input_batch, target_batch in train_loader:
            optimizer.zero_grad()  # Reset loss gradients from previous batch iteration
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            loss.backward()   # Calculate loss gradients
            optimizer.step()  # Update model weights using loss gradients
            examples_seen += input_batch.shape[0]  # New: track examples instead of tokens
            global_step += 1

            # Optional evaluation step
            if global_step % eval_freq == 0:
                train_loss, val_loss = evaluate_model(
                    model, train_loader, val_loader, device, eval_iter)
                train_losses.append(train_loss)
                val_losses.append(val_loss)
                print(f"Ep {epoch+1} (Step {global_step:06d}): "
                      f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}")

        # Calculate accuracy after each epoch
        train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=eval_iter)
        val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=eval_iter)
        print(f"Training accuracy: {train_accuracy*100:.2f}% | ", end="")
        print(f"Validation accuracy: {val_accuracy*100:.2f}%")
        train_accs.append(train_accuracy)
        val_accs.append(val_accuracy)

    return train_losses, val_losses, train_accs, val_accs, examples_seen
```
The evaluation function:
```python
def evaluate_model(model, train_loader, val_loader, device, eval_iter):
    model.eval()
    with torch.no_grad():
        train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter)
        val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter)
    model.train()
    return train_loss, val_loss
```
Start fine-tuning:
```python
import time

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # device was used above but never defined
model.to(device)

start_time = time.time()

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.1)
num_epochs = 5

train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple(
    model, train_loader, val_loader, optimizer, device,
    num_epochs=num_epochs, eval_freq=50, eval_iter=5,
)

end_time = time.time()
execution_time_minutes = (end_time - start_time) / 60
print(f"Training completed in {execution_time_minutes:.2f} minutes.")
```
The training log:
```
Ep 1 (Step 000000): Train loss 2.153, Val loss 2.392
Ep 1 (Step 000050): Train loss 0.617, Val loss 0.637
Ep 1 (Step 000100): Train loss 0.523, Val loss 0.557
Training accuracy: 70.00% | Validation accuracy: 72.50%
Ep 2 (Step 000150): Train loss 0.561, Val loss 0.489
Ep 2 (Step 000200): Train loss 0.419, Val loss 0.397
Ep 2 (Step 000250): Train loss 0.409, Val loss 0.353
Training accuracy: 82.50% | Validation accuracy: 85.00%
Ep 3 (Step 000300): Train loss 0.333, Val loss 0.320
Ep 3 (Step 000350): Train loss 0.340, Val loss 0.306
Training accuracy: 90.00% | Validation accuracy: 90.00%
Ep 4 (Step 000400): Train loss 0.136, Val loss 0.200
Ep 4 (Step 000450): Train loss 0.153, Val loss 0.132
Ep 4 (Step 000500): Train loss 0.222, Val loss 0.137
Training accuracy: 100.00% | Validation accuracy: 97.50%
Ep 5 (Step 000550): Train loss 0.207, Val loss 0.143
Ep 5 (Step 000600): Train loss 0.083, Val loss 0.074
Training accuracy: 100.00% | Validation accuracy: 97.50%
Training completed in 26.73 minutes.
```
Plot the loss and accuracy over the course of training:
```python
import matplotlib.pyplot as plt

def plot_values(epochs_seen, examples_seen, train_values, val_values, label="loss"):
    fig, ax1 = plt.subplots(figsize=(5, 3))

    # Plot training and validation loss against epochs
    ax1.plot(epochs_seen, train_values, label=f"Training {label}")
    ax1.plot(epochs_seen, val_values, linestyle="-.", label=f"Validation {label}")
    ax1.set_xlabel("Epochs")
    ax1.set_ylabel(label.capitalize())
    ax1.legend()

    # Create a second x-axis for examples seen
    ax2 = ax1.twiny()  # Create a second x-axis that shares the same y-axis
    ax2.plot(examples_seen, train_values, alpha=0)  # Invisible plot for aligning ticks
    ax2.set_xlabel("Examples seen")

    fig.tight_layout()  # Adjust layout to make room
    plt.savefig(f"{label}-plot.pdf")
    plt.show()

epochs_tensor = torch.linspace(0, num_epochs, len(train_losses))
examples_seen_tensor = torch.linspace(0, examples_seen, len(train_losses))
plot_values(epochs_tensor, examples_seen_tensor, train_losses, val_losses)

epochs_tensor = torch.linspace(0, num_epochs, len(train_accs))
examples_seen_tensor = torch.linspace(0, examples_seen, len(train_accs))
plot_values(epochs_tensor, examples_seen_tensor, train_accs, val_accs, label="accuracy")
```
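Before packaging the model for inference, it is also worth measuring accuracy over the full training, validation, and test sets rather than only the few batches sampled during training; a short sketch using calc_accuracy_loader from above:

```python
train_accuracy = calc_accuracy_loader(train_loader, model, device)
val_accuracy = calc_accuracy_loader(val_loader, model, device)
test_accuracy = calc_accuracy_loader(test_loader, model, device)

print(f"Training accuracy: {train_accuracy*100:.2f}%")
print(f"Validation accuracy: {val_accuracy*100:.2f}%")
print(f"Test accuracy: {test_accuracy*100:.2f}%")
```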
5. Packaging the inference code
Once the spam classifier is trained, wrap it into a callable classification function.
```python
def classify_review(text, model, tokenizer, device, max_length=None, pad_token_id=50256):
    model.eval()

    input_ids = tokenizer.encode(text)
    supported_context_length = model.pos_emb.weight.shape[0]
    if max_length is None:
        max_length = supported_context_length  # fall back to the model's context length

    # Truncate to the supported length, then pad to max_length
    input_ids = input_ids[:min(max_length, supported_context_length)]
    input_ids += [pad_token_id] * (max_length - len(input_ids))

    input_tensor = torch.tensor(input_ids, device=device).unsqueeze(0)  # add batch dimension

    with torch.no_grad():
        logits = model(input_tensor)[:, -1, :]  # Logits of the last output token
    predicted_label = torch.argmax(logits, dim=-1).item()

    return "spam" if predicted_label == 1 else "not spam"
```
Try it out:
```python
text_1 = (
    "You are a winner you have been specially"
    " selected to receive $1000 cash or a $2000 award."
)

print(classify_review(
    text_1, model, tokenizer, device, max_length=train_dataset.max_length
))
# spam — the message is classified as spam
```
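A common final step (not shown in the walkthrough above) is to persist the fine-tuned classifier so it can be reloaded later without retraining; the file name below is just an example:

```python
torch.save(model.state_dict(), "spam_classifier.pth")

# Later: rebuild the modified architecture (GPT-2 backbone + 2-class head) and load the weights
state_dict = torch.load("spam_classifier.pth", map_location=device)
model.load_state_dict(state_dict)
```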
In the next article we will cover instruction fine-tuning. Stay tuned.
Original article: https://mp.weixin.qq.com/s/P2oThAtGB8cs5PC1Y-b1pw