掘金 人工智能 前天 14:39
大厂特邀大咖万字深度穿透:Transformer核心模块实现细节大揭秘​
index_new5.html
../../../zaker_core/zaker_tpl_static/wap/tpl_guoji1.html

 

本文深入浅出地介绍了Transformer模型的核心——Self-Attention机制。文章首先阐述了传统序列模型的局限性,然后详细解读了Self-Attention的核心思想,包括Query、Key和Value的概念,以及Scaled Dot-Product Attention的原理。接着,文章介绍了Multi-Head Attention的实现细节,并通过代码和可视化,帮助读者更好地理解。此外,文章还探讨了Transformer中不同类型的注意力应用,以及高级应用,例如注意力变体和高效注意力实现。最后,文章给出了TransformerBlock的完整实现,帮助读者构建自己的Transformer模型。

💡Self-Attention机制是Transformer的核心,它通过Query、Key和Value三个向量来捕捉序列中不同位置之间的依赖关系。Query是当前关注的词向量,Key用于被查询,Value传递信息。

✨Scaled Dot-Product Attention是Self-Attention的核心计算模块。它通过计算Query和Key的点积,并进行缩放和Softmax归一化,最终得到注意力权重,再与Value向量加权求和,从而实现对序列信息的有效处理。

🌟Multi-Head Attention通过多个注意力头并行计算,使模型能够从不同的角度理解输入序列,增强模型的表达能力。每个头学习不同的特征子空间,最终将所有头的输出拼接起来进行处理。

🚀Transformer模型中,Self-Attention被应用于编码器和解码器,实现了对输入序列的全面建模。编码器使用自注意力捕捉输入序列内部关系,解码器使用自注意力捕捉目标序列内部关系,以及编码器-解码器注意力处理输入和输出序列之间的关系。

本文较长,建议点赞收藏,以免遗失。更多AI大模型应用开发学习视频及资料,尽在聚客AI学院

本文深入剖析Transformer的核心创新——Self-Attention机制,通过数学推导、代码实现和可视化,全面讲解Query/Key/Value概念、Scaled Dot-Product Attention原理以及Multi-Head Attention实现细节。

一、Self-Attention机制:序列建模的革命

1.1 传统序列建模的局限性

graph LR    A[RNN/LSTM] --> B[顺序处理]    B --> C[无法并行]    C --> D[长程依赖衰减]    D --> E[梯度消失/爆炸]

1.2 Self-Attention核心思想

import torchimport torch.nn as nnimport numpy as npimport matplotlib.pyplot as plt# 输入序列 (batch_size=1, seq_length=4, embedding_dim=8)x = torch.tensor([[    [1.00.50.82.00.11.50.31.2],    [0.71.20.41.80.90.61.10.2],    [1.30.31.70.61.40.80.51.9],    [0.21.51.10.70.31.81.60.4]]])print("输入序列形状:", x.shape)

Self-Attention三大核心向量:

class SelfAttention(nn.Module):    def __init__(self, embed_size):        super().__init__()        self.embed_size = embed_size        # 线性变换层        self.Wq = nn.Linear(embed_size, embed_size)        self.Wk = nn.Linear(embed_size, embed_size)        self.Wv = nn.Linear(embed_size, embed_size)        def forward(self, x):        Q = self.Wq(x)  # Query        K = self.Wk(x)  # Key        V = self.Wv(x)  # Value        return Q, K, V# 生成Q,K,Vattention = SelfAttention(embed_size=8)Q, K, V = attention(x)print("Query形状:", Q.shape)print("Key形状:", K.shape)print("Value形状:", V.shape)

Self-Attention核心优势:

全局依赖:直接捕获任意位置间的关系

并行计算:所有位置同时计算注意力

长程建模:无距离衰减的信息传递

可解释性:注意力权重可视化决策依据

二、Scaled Dot-Product Attention:注意力计算核心

2.1 数学原理详解

计算步骤分解:

相似度计算:QKTQK^T(查询与键的点积)

缩放处理:除以dk\sqrt{d_k}(防止梯度消失)

权重归一化:softmax函数

加权求和:乘以Value向量

def scaled_dot_product_attention(Q, K, V):    # Step 1: 计算Q和K的点积    matmul_qk = torch.matmul(Q, K.transpose(-2, -1))        # Step 2: 缩放处理    d_k = K.size(-1)    scaled_attention_logits = matmul_qk / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))        # Step 3: softmax归一化    attention_weights = torch.softmax(scaled_attention_logits, dim=-1)        # Step 4: 加权求和    output = torch.matmul(attention_weights, V)        return output, attention_weights# 计算注意力output, attn_weights = scaled_dot_product_attention(Q, K, V)print("注意力输出形状:", output.shape)print("注意力权重形状:", attn_weights.shape)

2.2 注意力权重可视化

# 可视化注意力权重plt.figure(figsize=(108))plt.imshow(attn_weights.detach().squeeze().numpy(), cmap='viridis')plt.title('Self-Attention权重矩阵')plt.xlabel('Key位置')plt.ylabel('Query位置')plt.colorbar()plt.xticks(range(4), ['词1''词2''词3''词4'])plt.yticks(range(4), ['词1''词2''词3''词4'])# 添加权重值for i in range(attn_weights.shape[-2]):    for j in range(attn_weights.shape[-1]):        plt.text(j, i, f"{attn_weights[0,i,j].item():.2f}",                  ha="center", va="center", color="w")plt.show()

缩放因子dk\sqrt{d_k}的数学意义:
假设qqkk是独立随机变量,均值为0,方差为1
则点积qk=i=1dkqikiq \cdot k = \sum_{i=1}^{d_k} q_i k_i的:

缩放后方差变为1,保持梯度稳定性:

三、Multi-Head Attention:多视角注意力

3.1 多头注意力原理

graph LR    A[输入向量] --> B[线性变换]    B --> C1[头1 QKV]    B --> C2[头2 QKV]    B --> C3[头n QKV]    C1 --> D1[Scaled Dot-Attention]    C2 --> D2[Scaled Dot-Attention]    C3 --> Dn[Scaled Dot-Attention]    D1 --> E[拼接输出]    D2 --> E    Dn --> E    E --> F[线性变换]    F --> G[最终输出]

3.2 完整多头注意力实现

class MultiHeadAttention(nn.Module):    def __init__(self, embed_size, num_heads):        super().__init__()        self.embed_size = embed_size        self.num_heads = num_heads        self.head_dim = embed_size // num_heads                assert self.head_dim * num_heads == embed_size, "嵌入维度必须是头数的整数倍"                # 线性变换层        self.Wq = nn.Linear(embed_size, embed_size)        self.Wk = nn.Linear(embed_size, embed_size)        self.Wv = nn.Linear(embed_size, embed_size)        self.fc_out = nn.Linear(embed_size, embed_size)        def split_heads(self, x):        """将嵌入维度分割为多个头"""        batch_size, seq_length, _ = x.size()        return x.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)        def forward(self, Q, K, V, mask=None):        batch_size = Q.size(0)                # 线性变换        Q = self.Wq(Q)        K = self.Wk(K)        V = self.Wv(V)                # 分割多头        Q = self.split_heads(Q)  # (batch_size, num_heads, seq_len, head_dim)        K = self.split_heads(K)        V = self.split_heads(V)                # 计算缩放点积注意力        attn_output, attn_weights = scaled_dot_product_attention(Q, K, V)                # 拼接多头输出        attn_output = attn_output.transpose(12).contiguous().view(            batch_size, -1, self.embed_size)                # 最终线性变换        output = self.fc_out(attn_output)                return output, attn_weights# 测试多头注意力embed_size = 8num_heads = 2multihead_attn = MultiHeadAttention(embed_size, num_heads)output, attn_weights = multihead_attn(x, x, x)print("多头注意力输出形状:", output.shape)print("多头注意力权重形状:", attn_weights.shape)  # (batch_size, num_heads, seq_len, seq_len)

3.3 多头注意力可视化

# 可视化不同头的注意力权重fig, axes = plt.subplots(1, num_heads, figsize=(155))for i in range(num_heads):    ax = axes[i]    head_weights = attn_weights[0, i].detach().numpy()    im = ax.imshow(head_weights, cmap='viridis')    ax.set_title(f'头 {i+1} 注意力权重')    ax.set_xlabel('Key位置')    ax.set_ylabel('Query位置')    fig.colorbar(im, ax=ax)        # 添加权重值    for row in range(head_weights.shape[0]):        for col in range(head_weights.shape[1]):            ax.text(col, row, f"{head_weights[row, col]:.2f}",                     ha="center", va="center", color="w", fontsize=8)plt.tight_layout()plt.show()

多头注意力的优势:

多视角建模:每个头关注不同特征空间

并行计算:多个头可同时独立计算

表征能力增强:组合不同子空间信息

可解释性提升:不同头可学习不同关系

四、Transformer中的注意力应用

4.1 编码器-解码器注意力

4.2 三种注意力模式

编码器自注意力:源序列内部关系

encoder_self_attn = MultiHeadAttention(embed_size, num_heads)encoder_output, _ = encoder_self_attn(src, src, src)

解码器自注意力:目标序列内部关系(带掩码)

# 创建下三角掩码def create_mask(size):    mask = torch.tril(torch.ones(size, size))    return mask.masked_fill(mask == 0, float('-inf'))mask = create_mask(tgt.size(1))decoder_self_attn = MultiHeadAttention(embed_size, num_heads)decoder_output, _ = decoder_self_attn(tgt, tgt, tgt, mask)

编码器-解码器注意力:源序列与目标序列间关系

cross_attn = MultiHeadAttention(embed_size, num_heads)cross_output, _ = cross_attn(decoder_output, encoder_output, encoder_output)

4.3 完整Transformer层实现

class TransformerBlock(nn.Module):    """完整的Transformer编码器层"""    def __init__(self, embed_size, num_heads, ff_dim, dropout=0.1):        super().__init__()        # 多头注意力        self.attention = MultiHeadAttention(embed_size, num_heads)        # 前馈网络        self.feed_forward = nn.Sequential(            nn.Linear(embed_size, ff_dim),            nn.ReLU(),            nn.Linear(ff_dim, embed_size)        )        # 归一化层        self.norm1 = nn.LayerNorm(embed_size)        self.norm2 = nn.LayerNorm(embed_size)        # Dropout        self.dropout = nn.Dropout(dropout)        def forward(self, x):        # 残差连接1        attn_output, _ = self.attention(x, x, x)        x = self.norm1(x + self.dropout(attn_output))                # 残差连接2        ff_output = self.feed_forward(x)        x = self.norm2(x + self.dropout(ff_output))                return x# 测试Transformer层transformer_block = TransformerBlock(    embed_size=8,     num_heads=2,     ff_dim=32)output = transformer_block(x)print("Transformer层输出形状:", output.shape)

五、注意力机制高级应用

5.1 注意力变体比较

5.2 自注意力与卷积的融合

class ConvAttention(nn.Module):    """卷积增强的自注意力"""    def __init__(self, embed_size, num_heads, kernel_size=3):        super().__init__()        self.attention = MultiHeadAttention(embed_size, num_heads)        self.conv = nn.Conv1d(            in_channels=embed_size,            out_channels=embed_size,            kernel_size=kernel_size,            padding=kernel_size//2        )        self.norm = nn.LayerNorm(embed_size)        def forward(self, x):        # 自注意力路径        attn_out, _ = self.attention(x, x, x)                # 卷积路径 (需要调整维度)        conv_out = self.conv(x.transpose(12)).transpose(12)                # 融合并归一化        combined = attn_out + conv_out        return self.norm(combined)# 测试卷积注意力conv_attn = ConvAttention(embed_size=8, num_heads=2)output = conv_attn(x)print("卷积注意力输出形状:", output.shape)

5.3 高效注意力实现

class EfficientAttention(nn.Module):    """线性复杂度的注意力机制"""    def __init__(self, embed_size):        super().__init__()        self.embed_size = embed_size        # 特征变换        self.to_query = nn.Linear(embed_size, embed_size)        self.to_key = nn.Linear(embed_size, embed_size)        self.to_value = nn.Linear(embed_size, embed_size)            def forward(self, x):        Q = self.to_query(x)        K = self.to_key(x)        V = self.to_value(x)                # 高效计算 (避免显式计算QK^T)        K = K.softmax(dim=1)        context = torch.einsum('bnd,bne->bde', K, V)        output = torch.einsum('bnd,bde->bne', Q, context)                return output# 测试高效注意力eff_attn = EfficientAttention(embed_size=8)output = eff_attn(x)print("高效注意力输出形状:", output.shape)

六、Self-Attention实战:文本分类

6.1 数据准备

from torchtext.datasets import IMDBfrom torchtext.data import get_tokenizerfrom torchtext.vocab import build_vocab_from_iterator# 加载IMDB数据集train_iter = IMDB(split='train')tokenizer = get_tokenizer('basic_english')# 构建词汇表def yield_tokens(data_iter):    for _, text in data_iter:        yield tokenizer(text)vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=['<unk>''<pad>'])vocab.set_default_index(vocab['<unk>'])# 文本转张量def text_pipeline(text):    return vocab(tokenizer(text))# 创建批次处理函数def collate_batch(batch, max_len=512):    label_list, text_list = [], []    for label, text in batch:        label_list.append(1 if label=='pos' else 0)        processed_text = text_pipeline(text)[:max_len]        processed_text += [vocab['<pad>']] * (max_len - len(processed_text))        text_list.append(processed_text)    return torch.tensor(label_list), torch.tensor(text_list)# 创建数据加载器from torch.utils.data import DataLoadertrain_loader = DataLoader(    list(IMDB(split='train')),     batch_size=32,     collate_fn=collate_batch)

6.2 基于Self-Attention的分类模型

class AttentionClassifier(nn.Module):    def __init__(self, vocab_size, embed_dim, num_heads, hidden_dim, num_classes):        super().__init__()        # 嵌入层        self.embedding = nn.Embedding(vocab_size, embed_dim)        # 位置编码        self.pos_encoding = nn.Parameter(torch.randn(1512, embed_dim))        # 自注意力层        self.attention = MultiHeadAttention(embed_dim, num_heads)        # 分类器        self.classifier = nn.Sequential(            nn.Linear(embed_dim, hidden_dim),            nn.ReLU(),            nn.Dropout(0.5),            nn.Linear(hidden_dim, num_classes)        )        def forward(self, x):        # 嵌入层        x = self.embedding(x)  # (batch, seq, embed_dim)                # 添加位置编码        seq_len = x.size(1)        x = x + self.pos_encoding[:, :seq_len, :]                # 自注意力        attn_output, _ = self.attention(x, x, x)                # 全局平均池化        pooled = attn_output.mean(dim=1)                # 分类        return self.classifier(pooled)# 初始化模型vocab_size = len(vocab)model = AttentionClassifier(    vocab_size=vocab_size,    embed_dim=128,    num_heads=4,    hidden_dim=256,    num_classes=2)# 训练配置criterion = nn.CrossEntropyLoss()optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

6.3 训练与注意力可视化

# 训练循环for epoch in range(5):    total_loss = 0    correct = 0    total = 0        for labels, texts in train_loader:        optimizer.zero_grad()        outputs = model(texts)        loss = criterion(outputs, labels)        loss.backward()        optimizer.step()                total_loss += loss.item()        _, predicted = outputs.max(1)        total += labels.size(0)        correct += predicted.eq(labels).sum().item()        accuracy = 100. * correct / total    print(f"Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.4f}, Acc: {accuracy:.2f}%")# 可视化样本注意力def visualize_attention(text):    # 预处理文本    tokens = tokenizer(text)    indexed = [vocab[token] for token in tokens][:512]    input_tensor = torch.tensor([indexed])        # 获取模型输出和注意力权重    model.eval()    with torch.no_grad():        embeddings = model.embedding(input_tensor)        _, attn_weights = model.attention(embeddings, embeddings, embeddings)        attn_weights = attn_weights.mean(dim=1)  # 平均多头        # 可视化    plt.figure(figsize=(126))    plt.imshow(attn_weights.squeeze().numpy(), cmap='viridis')    plt.title('文本注意力权重')    plt.xlabel('Key位置')    plt.ylabel('Query位置')    plt.xticks(range(len(tokens)), tokens, rotation=90)    plt.yticks(range(len(tokens)), tokens)    plt.colorbar()    plt.tight_layout()    plt.show()# 测试样例sample_text = "This movie is absolutely fantastic and captivating from start to finish"visualize_attention(sample_text)

关键要点总结

Self-Attention核心公式:

Multi-Head Attention处理流程:

flowchart LR    A[输入] --> B[线性变换]    B --> C[分割多头]    C --> D[Scaled Dot-Product]    D --> E[拼接输出]    E --> F[线性变换]    F --> G[输出]

Transformer中注意力的三种应用:

注意力机制超参数选择:

掌握Self-Attention机制是理解现代大模型的基础,通过本文的数学推导和代码实践,你已经具备了实现和优化注意力模型的核心能力!更多AI大模型应用开发学习视频内容和资料,尽在聚客AI学院

Fish AI Reader

Fish AI Reader

AI辅助创作,多种专业模板,深度分析,高质量内容生成。从观点提取到深度思考,FishAI为您提供全方位的创作支持。新版本引入自定义参数,让您的创作更加个性化和精准。

FishAI

FishAI

鱼阅,AI 时代的下一个智能信息助手,助你摆脱信息焦虑

联系邮箱 441953276@qq.com

相关标签

Transformer Self-Attention Multi-Head Attention 深度学习 自然语言处理
相关文章