机器学习初学者 02月11日
【深度学习】5分钟解释Transformer,一个神奇的算法模型 !!
index_new5.html
../../../zaker_core/zaker_tpl_static/wap/tpl_guoji1.html

 

本文以直观语言解释Transformer模型,通过多语言聊天室等例子,介绍其编码器、解码器、注意力机制等内容,还包括模型组成、代码示例及训练过程

💻Transformer能一次性处理整个输入,更快更准确理解数据

🧠编码器将词语转换成向量,注意力机制决定词的重要性

📝解码器根据编码器信息和自身知识生成回复

🌐多头注意力机制可并行处理多语言消息

cos大壮 2025-02-11 14:06 浙江

通俗易懂的语言解释transformer模型!

最近很多初学者在聊天当中,提到了Transformer模型。在前几天的发文中也说了几次。

今天还是想以非常直观、简单的语言,让大家有一个大的轮廓,从而更清晰的、更好的学习后面的理论部分。

最初听过一个很简单的例子:想象你在读一篇文章,如果你只能逐字阅读,你会花很长时间才能理解每个句子。但如果你能一次看到整篇文章,你会更快抓住文章的意思。Transformer 的能力就像这样,能一次性处理整个输入,从而更快、更准确地理解数据。

这里,我们再用一个非常直观地例子给大家~

假设你加入了一个多语言的在线聊天室,有些人用中文发送消息,有些人用英文。作为一个只懂英语的 Transformer 模型,让我们分步骤来看它如何处理消息:

1. 接收消息和理解(编码器)

每当有人发送一条消息,比如:

作为一个 Transformer 模型:

2. 理解和生成(解码器)

当编码器把消息转换成内部表示后,解码器负责生成回复:

3. 处理多语言

现在假设聊天室中有:

作为 Transformer 模型:

通过编码器和解码器的组合,利用注意力机制和多头注意力机制来有效地理解和生成文本数据,无论消息是中文还是英文,都能够得到适当的处理和回复。

Transformer 特别擅长处理序列数据,如自然语言文本。最初由 Google 提出的 Transformer 被用来处理文本翻译任务,现在它在多种任务中表现优异,包括文本生成、分类和信息提取等。和传统的序列模型(如 RNN)不同,Transformer 通过并行处理整个输入序列,大大提高了处理速度和效率。

基本原理

Transformer 的核心组件是 注意力机制 (Attention Mechanism),它允许模型在处理每个元素时,同时参考输入序列中的所有其他元素。

Transformer 主要由两个部分组成:编码器(Encoder)和 解码器(Decoder)。

编码器

解码器

注意力机制

注意力机制的目标是根据输入的每个单词生成一个权重,表示该单词对当前任务的重要性。计算这些权重的过程称为点积注意力 (Scaled Dot-Product Attention),其公式如下:

其中:

多头注意力机制

为了捕捉不同位置之间的关系,Transformer 使用 多头注意力机制 (Multi-Head Attention),它将查询、键、值矩阵拆分为多组,然后独立地应用注意力机制,最后将这些结果合并。

其中:

完整案例

下面我们通过一个简单的 Python 代码示例,演示如何使用 Transformer 进行文本分类。

代码中,使用 PyTorch 和其 Transformer 模块。

准备数据

我们使用一个示例数据集,其中每个句子被标注为正面或负面。

import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer
# 示例数据集
class SimpleDataset(Dataset):
    def __init__(self, texts, labels):
        self.texts = texts
        self.labels = labels
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.max_len = 128
    def __len__(self):
        return len(self.texts)
    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        return {
            'text': text,
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }
# 创建数据集
texts = ["I love this movie!""This was a terrible film."]
labels = [10]
dataset = SimpleDataset(texts, labels)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

定义模型

我们使用预训练的 BERT 模型进行分类任务。

from transformers import BertModel, BertConfig
import torch.nn as nn
class SimpleTransformerModel(nn.Module):
    def __init__(self, num_labels):
        super(SimpleTransformerModel, self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)
    
    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs[1]  # 取池化后的输出
        logits = self.classifier(pooled_output)
        return logits
# 初始化模型
model = SimpleTransformerModel(num_labels=2)

训练模型

简单的训练过程如下:

import torch.optim as optim
from torch.nn import CrossEntropyLoss
# 损失函数和优化器
criterion = CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=2e-5)
# 训练循环
model.train()
for epoch in range(3):  # 训练 3 个 epoch
    for batch in dataloader:
        optimizer.zero_grad()
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels']
        outputs = model(input_ids, attention_mask)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    print(f'Epoch {epoch+1}, Loss: {loss.item()}')

最后

Transformer 目前来说是一种非常重要的架构。它通过注意力机制高效处理序列数据,克服了传统模型的许多限制。

喜欢本文的朋友可以收藏、点赞、转发起来!


阅读原文

跳转微信打开

Fish AI Reader

Fish AI Reader

AI辅助创作,多种专业模板,深度分析,高质量内容生成。从观点提取到深度思考,FishAI为您提供全方位的创作支持。新版本引入自定义参数,让您的创作更加个性化和精准。

FishAI

FishAI

鱼阅,AI 时代的下一个智能信息助手,助你摆脱信息焦虑

联系邮箱 441953276@qq.com

相关标签

Transformer模型 注意力机制 编码器 解码器 文本处理
相关文章