[Building a Large Model by Hand] Writing GPT2 from Scratch — Attention

This article gives an accessible introduction to the attention mechanism, a key technique in deep learning for capturing dependencies between tokens. It first explains the core idea of attention: simulating a "query-retrieve-extract" process through the interaction of Query, Key, and Value. It then walks through the computation of scaled dot-product attention, including the matrix projections for Q, K, and V, similarity scoring, scaling, softmax normalization, and the weighted sum. Finally, it covers causal attention and dropout, and builds up to an implementation of multi-head attention, giving a complete picture of how the attention mechanism works and how it is used.

💡 The core of the attention mechanism is to simulate a "query-retrieve-extract" process: the three tensors Query, Key, and Value capture the dependencies between tokens and are used to build up context.

🔑 Scaled dot-product attention is the concrete realization of this idea: it measures the relevance between tokens as the dot product of Query and Key, applies scaling and a softmax, and returns a weighted sum of the Values.

🛡️ Causal attention uses masking to hide future tokens, ensuring that during generation the model can only rely on the context it has already seen, which is essential for autoregressive models.

🧠 Dropout is applied inside the attention mechanism to improve generalization: it randomly discards part of the weights, forcing the model to learn more robust representations.

➕ Multi-head attention runs several attention heads in parallel so the model can attend to different aspects of the meaning, which strengthens its expressive power.

Understand Attention in an Intuitive Way

In the previous post we covered token embeddings and position embeddings. An embedding, however, is still essentially a vector that describes the token itself; once training is finished it is just a set of fixed weights. We need a new mechanism that attends to the dependencies between tokens, that is, to the context a token appears in.

Two examples:

“He sat by the bank of the river.”

"The cat that the dog chased was black."

Whether bank means a financial institution or a riverbank depends on the nearby river; and although black sits closer to dog, we still know it is the cat being described.

Weighing the influence of different tokens in a sentence, i.e. attending to semantic relevance, is exactly the starting point of the attention mechanism. In other words, the essence of attention is to let the model judge at each step which tokens are most relevant, and to build the context from that.

The attention paper introduced three tensors, Query, Key, and Value, for this purpose:

An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility function of the query with the corresponding key.

Here, Query represents the question or point of interest, Key represents the index of a piece of information, and Value represents the information itself.

Intuitively, attention simulates a "query-retrieve-extract" process. The following rough analogy may help:

1) Suppose you go to a library to find books (the query is the topic you want to learn about);

2) the library has many shelf labels (keys) and books (values);

3) you first check how relevant each shelf label is to your topic (compatibility);

4) based on that relevance you decide how many books to take from each shelf (weighted by compatibility);

5) combining the contents of the books you took (weighted sum) gives your final answer (output).

Scaled Dot-Product Attention

Attention Definition

Attention is formally defined as:

\text{Attention}(Q, K, V) = \text{softmax}\left( \frac{Q K^T}{\sqrt{d_k}} \right) V

Here Q, K, and V are the matrices for the queries, keys, and values, and d_k is the dimension of the key vectors, used for scaling.

The definition is concise, and the computation behind it is just as clear and simple.

The example below uses 5 tokens with an embedding dimension of 4:

import torch
import torch.nn as nn

torch.manual_seed(123)
tokens = ["Once", "upon", "a", "time", "there"]
token_to_idx = {token: idx for idx, token in enumerate(tokens)}
embedding_dim = 4
embedding_layer = nn.Embedding(num_embeddings=len(tokens), embedding_dim=embedding_dim)
input_indices = torch.tensor([token_to_idx[token] for token in tokens])  # [0,1,2,3,4]
X = embedding_layer(input_indices)
print("shape of input X:", X.shape)
print(X)

shape of input X: torch.Size([5, 4])

tensor([[ 0.3374, -0.1778, -0.3035, -0.5880],

[ 1.5810, 1.3010, 1.2753, -0.2010],

[-0.1606, -0.4015, 0.6957, -1.8061],

[-1.1589, 0.3255, -0.6315, -2.8400],

[-0.7849, -1.4096, -0.4076, 0.7953]], grad_fn=)

This gives us a 5×4 matrix.

Q K V matrix

Based on this matrix, constructing Q, K, and V is remarkably simple: choose the dimensions, randomly initialize a weight matrix for each, and use it to linearly project the input X:

torch.manual_seed(123)
W_Q = torch.nn.Parameter(torch.rand(embedding_dim, embedding_dim), requires_grad=False)
W_K = torch.nn.Parameter(torch.rand(embedding_dim, embedding_dim), requires_grad=False)
W_V = torch.nn.Parameter(torch.rand(embedding_dim, embedding_dim), requires_grad=False)
print("shape of W_Q:", W_Q.shape)
print("W_Q:", W_Q)

Q = X @ W_Q
K = X @ W_K
V = X @ W_V
print("shape of Q:", Q.shape)
print("Q:", Q)

shape of W_Q: torch.Size([4, 4])

W_Q: Parameter containing:

tensor([[0.2961, 0.5166, 0.2517, 0.6886],

[0.0740, 0.8665, 0.1366, 0.1025],

[0.1841, 0.7264, 0.3153, 0.6871],

[0.0756, 0.1966, 0.3164, 0.4017]])

shape of Q: torch.Size([5, 4])

Q: tensor([[-0.0136, -0.3159, -0.2211, -0.2307],

[ 0.7839, 2.8310, 0.9140, 2.0175],

[-0.0858, -0.2806, -0.4474, -0.3992],

[-0.6501, -1.3338, -1.3449, -2.3394],

[-0.3515, -1.7666, -0.2669, -0.6454]], grad_fn=)

Note that X has shape [5, 4] and the randomly initialized W_Q has shape [4, 4], so by the rules of matrix multiplication Q has shape [5, 4].

The matrix multiplication used here (@ is equivalent to torch.matmul) amounts to a linear projection of the original input X.

Note that in a real model W_Q, W_K, and W_V are all trainable parameters (they are frozen above only to keep the demo reproducible). Their initial values are therefore not what matters; what matters is the representational freedom and the information pathways they provide.
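As a quick, minimal sanity check (a sketch; W_Q_trainable is just an illustrative name), @ really is the same operation as torch.matmul, and a trainable projection simply keeps requires_grad=True:

# `@` and torch.matmul are the same operation
print(torch.allclose(X @ W_Q, torch.matmul(X, W_Q)))  # True
# In a real model the projection weights are trainable (requires_grad=True)
W_Q_trainable = torch.nn.Parameter(torch.rand(embedding_dim, embedding_dim))
print(W_Q_trainable.requires_grad)  # True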

Similarity as Scores

According to the formula above, \text{scores} = Q K^T: we compute the dot products between Q and K to measure how related, or similar, they are.

scores = Q @ K.T
print("shape of scores:", scores.shape)
print("scores:", scores)

shape of scores: torch.Size([5, 5])

scores: tensor([[ 0.3101, -2.0474, 0.7024, 1.8280, 1.0647],

[ -2.5714, 17.4476, -5.5017, -14.6920, -9.3044],

[ 0.6084, -2.9632, 1.4480, 3.1775, 1.4642],

[ 2.8736, -14.6337, 6.4597, 14.7155, 7.4156],

[ 0.9222, -8.1955, 1.8808, 5.9959, 4.5150]],

grad_fn=)

In this example Q and K both have shape [5, 4]; transposing K gives [4, 5], and multiplying the two produces a [5, 5] matrix.

Note: what happens here is really a batch of dot products computed at once, i.e. a matrix multiplication.
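To make that concrete, here is a small check (a sketch reusing the tensors above): each entry of scores is exactly the dot product of one query row with one key row.

i, j = 2, 3
print(scores[i, j])           # entry (i, j) of the score matrix
print(torch.dot(Q[i], K[j]))  # the same value, computed as a single dot product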

Scaled Scores

According to the formula, \text{scaled\_scores} = \frac{Q K^T}{\sqrt{d_k}}, so we scale the scores as follows:

import math

attention_scores = scores / math.sqrt(embedding_dim)
print(attention_scores)

tensor([[ 0.1551, -1.0237, 0.3512, 0.9140, 0.5323],

[-1.2857, 8.7238, -2.7508, -7.3460, -4.6522],

[ 0.3042, -1.4816, 0.7240, 1.5888, 0.7321],

[ 1.4368, -7.3169, 3.2298, 7.3577, 3.7078],

[ 0.4611, -4.0977, 0.9404, 2.9979, 2.2575]], grad_fn=)

So why scale at all, and why by this particular value? Scaling compresses the scores so that the subsequent softmax does not produce an overly extreme distribution, which keeps gradient flow healthy. The choice of \sqrt{d_k} has a statistical justification: if the components of a query and a key are independent with zero mean and unit variance, their dot product has variance d_k, so dividing by \sqrt{d_k} brings the variance back to roughly 1. That said, it is still something of a practical compromise; dividing by another constant of similar size would also work, and in essence this is just a way of normalizing the scores.
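A small numerical sketch of that argument (the sizes below are chosen only for illustration): for random vectors with unit-variance components, the variance of the raw dot products grows with d_k, and dividing by sqrt(d_k) brings it back to about 1.

torch.manual_seed(0)
d_k = 512
q = torch.randn(10000, d_k)        # components ~ N(0, 1)
k = torch.randn(10000, d_k)
dots = (q * k).sum(dim=-1)         # 10,000 independent query-key dot products
print(dots.var())                  # ≈ d_k, i.e. ≈ 512
print((dots / d_k ** 0.5).var())   # ≈ 1 after scaling by sqrt(d_k)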

With that, the scaled attention scores are computed.

Compute Attention Weights via Softmax

To turn the attention scores into usable weights, one more normalization step is needed: the softmax.

\text{attention\_weights} = \text{softmax}\left( \frac{Q K^T}{\sqrt{d_k}} \right)

A quick plot makes the softmax function easy to grasp:

import torch
import matplotlib.pyplot as plt

x = torch.linspace(-5, 5, 200)
scores = torch.stack([x, torch.zeros_like(x)], dim=1)
softmax_vals = torch.softmax(scores, dim=1)
plt.plot(x.numpy(), softmax_vals[:, 0].numpy())
plt.show()

As the plot shows, softmax squashes every input into the range (0, 1), so the outputs look like probabilities.

Note: softmax is essentially just a normalization; other functions with similar behavior could be substituted.
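For example, a minimal sketch showing that softmax is nothing more than exponentiation followed by normalization:

row0 = attention_scores[0]
manual = torch.exp(row0) / torch.exp(row0).sum()
print(torch.allclose(manual, torch.softmax(row0, dim=-1)))  # True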

We can simply use PyTorch's built-in softmax:

attention_weights = torch.softmax(attention_scores, dim=-1)
print("shape of attention_weights:", attention_weights.shape)
print(attention_weights)

shape of attention_weights: torch.Size([5, 5])

tensor([[1.6344e-01, 5.0283e-02, 1.9885e-01, 3.4910e-01, 2.3833e-01],

[4.4966e-05, 9.9994e-01, 1.0389e-05, 1.0494e-07, 1.5519e-06],

[1.2761e-01, 2.1395e-02, 1.9418e-01, 4.6106e-01, 1.9576e-01],

[2.5676e-03, 4.0538e-07, 1.5426e-02, 9.5713e-01, 2.4878e-02],

[4.6963e-02, 4.9191e-04, 7.5844e-02, 5.9361e-01, 2.8309e-01]],

grad_fn=)

All the resulting weights lie in (0, 1), which makes them well suited for weighted averaging.
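In fact each row of the weight matrix sums to 1, so every position gets a proper weighted average; a one-line check:

print(attention_weights.sum(dim=-1))  # a tensor of ones: each row is a probability distribution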

Output as weighted sum

Following the definition of attention, the weight matrix is multiplied by the value matrix to obtain the final attention output:

\text{output} = \text{attention\_weights} \cdot V

# Final output of self-attention
output = attention_weights @ V
print("shape of output:", output.shape)
print(output)

shape of output: torch.Size([5, 4])

tensor([[-1.0221, -1.1318, -1.0966, -1.2475],

[ 1.6613, 1.7716, 2.1347, 2.5049],

[-1.3064, -1.3985, -1.3982, -1.5418],

[-2.2928, -2.2490, -2.4211, -2.5138],

[-1.6010, -1.6693, -1.7563, -1.9028]], grad_fn=)

Watch the shapes: [5, 5] times [5, 4] gives an output of shape [5, 4], exactly the shape of the input X. In other words, after the attention transformation the output has the same dimensions as the input.
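As a sketch of what that multiplication means: row i of the output is the weighted average of the value vectors, with row i of the weight matrix as the coefficients.

manual_row0 = (attention_weights[0].unsqueeze(-1) * V).sum(dim=0)  # weighted sum of the 5 value vectors
print(torch.allclose(manual_row0, output[0]))  # True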

This completes the full attention computation.

Simple Self-Attention Code

With the steps above understood, building a self-attention module in PyTorch is straightforward:

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleSelfAttention(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        # (d_in, d_out)
        self.W_Q = nn.Linear(d_in, d_out, bias=False)
        self.W_K = nn.Linear(d_in, d_out, bias=False)
        self.W_V = nn.Linear(d_in, d_out, bias=False)

    def forward(self, x):
        # (seq_len, d_in) x (d_in, d_out) -> (seq_len, d_out)
        Q = self.W_Q(x)  # equal to: x @ W_Q.T
        K = self.W_K(x)
        V = self.W_V(x)
        # (seq_len, d_out) x (d_out, seq_len) -> (seq_len, seq_len)
        scores = Q @ K.transpose(-2, -1) / K.shape[-1] ** 0.5
        # (seq_len, seq_len)
        weights = F.softmax(scores, dim=-1)
        # (seq_len, seq_len) x (seq_len, d_out) -> (seq_len, d_out)
        context = weights @ V
        return context
torch.manual_seed(123)
sa = SimpleSelfAttention(4, 4)
output = sa(X)
print(output)

tensor([[ 0.1318, -0.1000, -0.4239, -0.0858],

[-0.0532, 0.2164, -0.8386, -0.1107],

[ 0.2318, -0.2270, -0.4083, -0.0919],

[ 0.4762, -0.5514, -0.2901, -0.0859],

[ 0.0700, -0.0399, -0.3281, -0.0728]], grad_fn=)

Pay particular attention to how the tensor shapes change.

Note: nn.Linear layers are used here to hold the Q/K/V weights; you could also create the parameter matrices by hand with nn.Parameter(torch.rand(d_in, d_out)). The two differ slightly in how the weights are initialized.
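A minimal sketch of that equivalence (lin and x_demo are illustrative names): nn.Linear stores its weight as (d_out, d_in), so calling the layer is the same as multiplying by the transposed weight matrix.

lin = nn.Linear(4, 4, bias=False)
x_demo = torch.randn(5, 4)
print(torch.allclose(lin(x_demo), x_demo @ lin.weight.T))  # True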

Causal Attention: Mask future words

The attention weights computed above take the whole context into account, but that does not match how a generative language model is trained. For example:

“He sat by the bank of the river.”

When the model is about to generate bank, its context may only contain the preceding words, not river or anything that comes after. If we let the model see the full context during training, it learns to rely on information it will never have at generation time, so it generalizes poorly and performs badly on real generation tasks. We therefore need to block out the "future words" the model should not see.

As discussed in the embedding chapter, training a large language model is an autoregressive process:

Once --> upon

Once upon --> a

Once upon a --> time

Once upon a time --> there

Once upon a time there --> were

Masking future words then becomes very simple: just remove every element above the diagonal of the attention matrix.

For example, with the lower-triangular matrix below we can mask out future tokens using ordinary matrix operations. The mask is built as follows:

import torch

context_size = attention_scores.shape[0]
# Lower triangular mask
mask = torch.tril(torch.ones(context_size, context_size))
print(mask)
mask = mask.masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, 0.0)
print(mask)

tensor([[1., 0., 0., 0., 0.],

[1., 1., 0., 0., 0.],

[1., 1., 1., 0., 0.],

[1., 1., 1., 1., 0.],

[1., 1., 1., 1., 1.]])

tensor([[0., -inf, -inf, -inf, -inf],

[0., 0., -inf, -inf, -inf],

[0., 0., 0., -inf, -inf],

[0., 0., 0., 0., -inf],

[0., 0., 0., 0., 0.]])

We end up with a matrix whose lower triangle is 0 and whose upper triangle is negative infinity.

Applying the mask is then a simple matrix addition:

print("original scores: \n", attention_scores)# Apply mask to scoresmasked_scores = attention_scores + maskprint("masked scores:\n", masked_scores)

original scores:

tensor([[ 0.1551, -1.0237, 0.3512, 0.9140, 0.5323],

[-1.2857, 8.7238, -2.7508, -7.3460, -4.6522],

[ 0.3042, -1.4816, 0.7240, 1.5888, 0.7321],

[ 1.4368, -7.3169, 3.2298, 7.3577, 3.7078],

[ 0.4611, -4.0977, 0.9404, 2.9979, 2.2575]], grad_fn=)

masked scores:

tensor([[ 0.1551, -inf, -inf, -inf, -inf],

[-1.2857, 8.7238, -inf, -inf, -inf],

[ 0.3042, -1.4816, 0.7240, -inf, -inf],

[ 1.4368, -7.3169, 3.2298, 7.3577, -inf],

[ 0.4611, -4.0977, 0.9404, 2.9979, 2.2575]], grad_fn=)

Only the lower triangle of the attention scores is kept; the upper triangle is filled with -inf. We use -inf because a softmax follows: softmax maps -inf to 0, so those positions contribute nothing to the weights.
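A one-line demonstration of that property:

print(torch.softmax(torch.tensor([2.0, float('-inf'), 0.5]), dim=-1))  # the -inf position gets weight 0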

Dropout

Another technique commonly used to improve generalization is random dropping, i.e. dropout. A quick PyTorch illustration:

import torch

torch.manual_seed(123)
# Create a dropout layer with 20% dropout rate
dropout = torch.nn.Dropout(0.2)
dropout.train()  # Explicitly set to training mode to enable dropout

example = torch.ones(5, 5)
print("Input tensor:\n", example)
# Apply dropout to the input tensor
output = dropout(example)
print("tensor after Dropout:\n", output)
print(f"Number of zeros in output: {(output == 0).sum().item()}")
print(f"Output mean value (should be ~1.0 due to scaling): {output.mean().item():.4f}")

Input tensor:

tensor([[1., 1., 1., 1., 1.],

[1., 1., 1., 1., 1.],

[1., 1., 1., 1., 1.],

[1., 1., 1., 1., 1.],

[1., 1., 1., 1., 1.]])

tensor after Dropout:

tensor([[1.2500, 1.2500, 1.2500, 1.2500, 1.2500],

[1.2500, 1.2500, 1.2500, 0.0000, 1.2500],

[0.0000, 1.2500, 1.2500, 1.2500, 1.2500],

[1.2500, 1.2500, 1.2500, 1.2500, 1.2500],

[1.2500, 1.2500, 1.2500, 1.2500, 1.2500]])

Number of zeros in output: 2

Output mean value (should be ~1.0 due to scaling): 1.1500

In the 5×5 all-ones matrix, some entries have been set to 0 and the remaining entries have become 1.25. This is because PyTorch rescales the surviving values by 1/(1 - drop_rate) so that the expected mean stays unchanged.

The result above does not look quite right only because the example matrix is tiny; the statistics only hold for large samples. Change the size to 500×500 and the mean comes out very close to 1. In the real GPT2 setting, as noted earlier, the weight matrix at this step has shape seq_len × seq_len, i.e. 1024 × 1024, which is more than large enough.
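Following that remark, a quick sketch (demo_dropout and big are illustrative names, kept separate from the dropout layer above) with a 500×500 matrix; it also shows that dropout is switched off entirely in eval mode:

big = torch.ones(500, 500)
demo_dropout = torch.nn.Dropout(0.2)
demo_dropout.train()
print(demo_dropout(big).mean())  # ≈ 1.0: with 250,000 elements the 1/(1 - p) scaling averages out
demo_dropout.eval()
print(demo_dropout(big).mean())  # exactly 1.0: dropout is a no-op at inference time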

In the attention mechanism, dropout is applied to the weights produced by the softmax:

weights = F.softmax(masked_scores, dim=-1)
print("weights after mask: \n", weights)

torch.manual_seed(123)
output = dropout(weights)
print("weights after Dropout: \n", output)

weights after mask:

tensor([[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],

[4.4967e-05, 9.9996e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00],

[3.7185e-01, 6.2345e-02, 5.6581e-01, 0.0000e+00, 0.0000e+00],

[2.6332e-03, 4.1573e-07, 1.5819e-02, 9.8155e-01, 0.0000e+00],

[4.6963e-02, 4.9191e-04, 7.5844e-02, 5.9361e-01, 2.8309e-01]],

grad_fn=)

weights after Dropout:

tensor([[1.2500e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],

[5.6209e-05, 1.2499e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],

[0.0000e+00, 7.7931e-02, 7.0726e-01, 0.0000e+00, 0.0000e+00],

[3.2914e-03, 5.1966e-07, 1.9774e-02, 1.2269e+00, 0.0000e+00],

[5.8704e-02, 6.1489e-04, 9.4804e-02, 7.4201e-01, 3.5387e-01]],

grad_fn=)

Applying softmax to the masked scores yields weights that live only in the lower triangle; when those weights pass through dropout, 20% of them are set to 0 and the survivors are scaled up to 1.25 times their value.

Causal Self-Attention Code

Adding the mask and dropout to the earlier SimpleSelfAttention gives the complete implementation:

import torch
import torch.nn as nn

class CausalAttention(nn.Module):
    """Implements single-head causal self-attention with optional dropout."""

    def __init__(self, d_in, d_out, context_length, dropout=0.0, qkv_bias=False):
        super().__init__()
        # (d_in, d_out)
        self.W_Q = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_K = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_V = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # Create a fixed causal mask (upper triangular) [1 means "mask"]
        mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
        self.register_buffer("mask", mask.bool())

    def forward(self, x):
        # x: shape (batch_size, seq_len, d_in)
        batch_size, seq_len, _ = x.size()
        Q = self.W_Q(x)
        K = self.W_K(x)
        V = self.W_V(x)
        # Compute attention scores, scaled by sqrt(d_out)
        scores = Q @ K.transpose(-2, -1) / K.shape[-1] ** 0.5  # (batch_size, seq_len, seq_len)
        # Apply causal mask
        scores = scores.masked_fill(self.mask[:seq_len, :seq_len], -torch.inf)
        # Compute softmax weights and apply dropout
        weights = torch.softmax(scores, dim=-1)
        weights = self.dropout(weights)
        # Compute output
        output = weights @ V  # (batch_size, seq_len, d_out)
        return output

To stay closer to a real setting, the input X now carries an extra batch dimension, so its shape is (batch_size, seq_len, d_in), and the output shape becomes (batch_size, seq_len, d_out) accordingly.

We generate a random batch of 2 sequences, each with 5 tokens of dimension 4, and run it through CausalAttention to compute the final context vectors:

torch.manual_seed(123)
batch = torch.randn(2, 5, 4)  # (batch_size=2, seq_len=5, d_in=4)

d_in = 4
d_out = 4
context_length = batch.size(1)
ca = CausalAttention(d_in, d_out, context_length, dropout=0.0)
context_vecs = ca(batch)
print("context_vecs.shape:", context_vecs.shape)
print("context_vecs:\n", context_vecs)

context_vecs.shape: torch.Size([2, 5, 4])

context_vecs:

tensor([[[-0.0487, -0.0112, 0.0449, 0.3506],

[ 0.0439, 0.1278, 0.1848, 0.1733],

[-0.2467, -0.1078, 0.2722, 0.5128],

[-0.1638, 0.0053, 0.3753, 0.3111],

[ 0.0264, 0.1455, 0.3622, 0.0182]],

[[ 0.0960, 0.4257, 1.7419, 0.2045],

[-0.0967, 0.2774, 1.1946, 0.5023],

[ 0.1017, 0.2037, 0.4849, 0.1862],

[-0.0775, 0.1062, 0.3737, 0.3387],

[-0.1181, -0.0113, 0.1070, 0.2743]]], grad_fn=)

Note that because d_out equals d_in here, the resulting context vectors have exactly the same shape as the input vectors.
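A quick way to sanity-check the causal mask (a sketch reusing ca and batch from above, with dropout at 0.0): overwriting the last token must not change the context vectors of the earlier positions.

batch2 = batch.clone()
batch2[:, -1, :] = 0.0                             # perturb only the final token
out1, out2 = ca(batch), ca(batch2)
print(torch.allclose(out1[:, :-1], out2[:, :-1]))  # True: earlier positions never see the future token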

With that, we have completed the full single-head attention computation from the attention paper.

Multi-Head Attention

So far we have computed attention with a single head. To further increase the model's expressive power, the paper goes one step further and introduces multi-head attention (see the multi-head attention figure in the paper).

For example, in

“The cat that the dog chased was black.”

different attention heads can each focus on a different semantic structure:

"cat" <------ "was"(Head 1 强 attention)Head 1:关注主语-谓语(句子主干)

"that", "dog", "chased" ----> "cat"(Head 2 强 attention)Head 2:关注定语从句的修饰结构

"dog" ----> "chased"(Head 3 强 attention)Head 3:关注宾语结构

"cat" <---- "black"(Head 4 attention)Head 4:关注形容词修饰关系

Concat heads code

The most direct way to implement multiple heads is simply to repeat the single head several times and stack the results:

class MultiHeadAttentionWrapper(nn.Module):
    """Implements multi-head self-attention by stacking multiple heads."""

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        self.head_dim = d_out // num_heads
        self.heads = nn.ModuleList(
            [CausalAttention(d_in, self.head_dim, context_length, dropout, qkv_bias)
             for _ in range(num_heads)])

    def forward(self, x):
        output = torch.cat([head(x) for head in self.heads], dim=-1)
        return output
torch.manual_seed(123)
batch = torch.randn(2, 5, 6)  # (batch_size=2, seq_len=5, d_in=6)

d_in = 6
d_out = 6
context_length = batch.size(1)
mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, dropout=0.0, num_heads=2)
context_vecs = mha(batch)
print("context_vecs.shape:", context_vecs.shape)
print("context_vecs:\n", context_vecs)

context_vecs.shape: torch.Size([2, 5, 6])

context_vecs:

tensor([[[-0.0067, -0.0370, 0.2712, -0.5243, -0.0242, -0.0438],

[-0.1782, 0.0173, -0.0166, -0.2391, -0.0284, 0.2177],

[-0.1541, 0.2878, -0.2018, 0.2535, 0.0242, 0.3002],

[-0.2817, 0.5219, -0.0699, 0.5508, -0.2767, 0.3709],

[-0.0355, -0.1721, 0.0981, 0.2389, -0.1460, 0.1938]],

[[ 0.7943, -1.9382, 0.2171, -1.6710, 0.7970, -1.3094],

[ 0.2519, -1.1446, 0.2991, -1.5203, 0.3135, -0.9541],

[ 0.1920, -0.8646, 0.3794, -0.9135, 0.0203, -0.5454],

[ 0.2565, -0.8320, 0.1292, -0.9259, 0.2156, -0.4762],

[ 0.1519, -0.5043, 0.1079, -0.3281, 0.1523, -0.1446]]],

grad_fn=)

In the code above we manually set up 2 heads, and adjust each head's dimension via d_out // num_heads.

Weight split code

This, however, is not how MHA (Multi-Head Attention) is really implemented: stacking separate heads is inefficient. The better approach is to do one large projection with a single big matrix and then split the result into heads, i.e. a weight split.

The code is as follows:

import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    """Implements multi-head attention by splitting the projected matrices into multiple heads."""

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        self.d_out = d_out
        self.W_Q = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_K = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_V = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_O = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
        self.register_buffer("mask", mask.bool())

    def forward(self, x):
        # x: shape (batch_size, seq_len, d_in)
        batch_size, seq_len, _ = x.size()
        # Project and split Q, K, V into multiple heads
        # (batch_size, seq_len, d_in) -> (batch_size, seq_len, d_out) ->
        # -> (batch_size, seq_len, num_heads, head_dim) -> (batch_size, num_heads, seq_len, head_dim)
        Q = self.W_Q(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.W_K(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.W_V(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # Compute attention scores
        # (scaling per head by self.head_dim ** 0.5 is the more common convention)
        scores = Q @ K.transpose(-2, -1) / (self.d_out ** 0.5)  # (batch_size, num_heads, seq_len, seq_len)
        # Apply causal mask
        scores = scores.masked_fill(self.mask[:seq_len, :seq_len], -torch.inf)
        # Compute softmax weights and apply dropout
        weights = torch.softmax(scores, dim=-1)
        weights = self.dropout(weights)
        # Compute output
        output = weights @ V  # (batch_size, num_heads, seq_len, head_dim)
        # Concatenate heads and project to output dimension
        # (batch_size, num_heads, seq_len, head_dim) -> (batch_size, seq_len, num_heads, head_dim)
        # -> (batch_size, seq_len, d_out)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        # Output projection to mix the heads. Helpful, but not strictly necessary.
        output = self.W_O(output)
        return output
torch.manual_seed(123)
batch = torch.randn(2, 5, 6)  # (batch_size=2, seq_len=5, d_in=6)

d_in = 6
d_out = 6
context_length = batch.size(1)
mha = MultiHeadAttention(d_in, d_out, context_length, dropout=0.0, num_heads=2)
context_vecs = mha(batch)
print("context_vecs.shape:", context_vecs.shape)
print("context_vecs:\n", context_vecs)

context_vecs.shape: torch.Size([2, 5, 6])

context_vecs:

tensor([[[-0.5829, -0.5644, 0.1930, -0.1541, 0.2518, -0.2252],

[-0.2962, -0.2681, 0.1179, 0.1136, 0.0953, -0.4015],

[-0.2039, -0.0745, 0.1557, -0.0494, 0.1125, -0.5282],

[-0.2540, 0.1181, 0.2729, -0.1182, 0.0321, -0.5292],

[-0.2007, 0.0280, 0.1645, -0.0798, 0.1264, -0.5020]],

[[-0.2307, -1.7354, -0.4065, 0.3778, 0.9090, -0.1498],

[-0.5355, -1.2480, -0.0049, 0.1522, 0.5635, -0.0269],

[-0.4674, -0.8466, 0.0176, 0.1337, 0.4053, -0.2230],

[-0.3683, -0.6768, 0.0088, 0.0933, 0.3034, -0.3600],

[-0.2545, -0.5944, -0.0236, 0.0762, 0.3629, -0.3780]]],

grad_fn=)

Compared with the previous version:

1) A single large W_Q (and likewise W_K, W_V) performs the projection, and the result is split across heads with a view operation.

2) An extra linear layer W_O is applied to the output to further mix the heads. This step is helpful but not strictly required.

Throughout these operations, pay close attention to how the tensor shapes change; once you can follow the shape changes, you have essentially understood the whole computation.
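To make those shape changes concrete, here is a small trace (a sketch reusing the mha instance above, with batch_size=2, seq_len=5, d_out=6, num_heads=2, head_dim=3; the *_demo names are only for illustration):

x_demo = torch.randn(2, 5, 6)                                # (batch_size, seq_len, d_in)
q_demo = mha.W_Q(x_demo)                                     # (2, 5, 6): one big projection
q_demo = q_demo.view(2, 5, 2, 3)                             # (2, 5, num_heads, head_dim): split into heads
q_demo = q_demo.transpose(1, 2)                              # (2, 2, 5, 3): heads act like a batch dimension
print(q_demo.shape)
merged = q_demo.transpose(1, 2).contiguous().view(2, 5, 6)   # (2, 5, 6): heads merged back together
print(merged.shape)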

With that we have completed the MHA implementation, the core component of the Transformer architecture in GPT2.
