Multi-Head Attention (MHA)
Multi-head attention computes several attention heads in parallel, allowing the model to attend to features at different positions of the input sequence at the same time. Its core idea is to project the input into multiple subspaces, compute attention weights in each subspace separately, and then aggregate the results, which strengthens the model's ability to capture complex patterns.
import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(MultiHeadAttention, self).__init__()
        self.embed_size = embed_size          # embedding dimension
        self.heads = heads                    # number of attention heads
        self.head_dim = embed_size // heads   # dimension of each head

        # The embedding dimension must be divisible by the number of heads
        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"

        # Per-head linear projections for values, keys and queries
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        # Final output projection
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        N = query.shape[0]  # batch size
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split the embeddings into multiple heads
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        # Apply the per-head linear projections
        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        # Dot-product attention scores
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        # queries shape: (N, query_len, heads, head_dim)
        # keys shape:    (N, key_len, heads, head_dim)
        # energy shape:  (N, heads, query_len, key_len)

        # Apply the mask (if provided)
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        # Softmax over the key dimension, scaled by sqrt(head_dim) as in the Transformer paper
        attention = torch.softmax(energy / (self.head_dim ** (1 / 2)), dim=3)

        # Weighted sum of the values
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )
        # attention shape: (N, heads, query_len, key_len)
        # values shape:    (N, value_len, heads, head_dim)
        # out after the einsum: (N, query_len, heads, head_dim), then the last two dims are flattened

        # Final output projection
        out = self.fc_out(out)
        return out


# Example usage:
if __name__ == "__main__":
    embed_size = 512   # embedding dimension
    heads = 8          # number of heads
    batch_size = 64    # batch size
    seq_length = 10    # sequence length

    mha = MultiHeadAttention(embed_size, heads)
    values = keys = query = torch.randn(batch_size, seq_length, embed_size)
    print(mha.forward(values, keys, query, mask=None).shape)
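The example above passes mask=None. If causal (autoregressive) masking is needed, one option is a lower-triangular mask that broadcasts against the (N, heads, query_len, key_len) score tensor. This is a minimal sketch that continues the example above; the mask shape and construction are one common choice, not the only one.

    # Lower-triangular causal mask: position i may only attend to positions <= i.
    # Shape (1, 1, seq_length, seq_length) broadcasts over batch and heads.
    causal_mask = torch.tril(torch.ones(seq_length, seq_length)).unsqueeze(0).unsqueeze(0)
    print(mha(values, keys, query, mask=causal_mask).shape)  # torch.Size([64, 10, 512])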
Multi-Query Attention (MQA)
Multi-Query Attention (MQA) is an efficiency-oriented variant of multi-head attention (MHA). Its core idea is to share a single set of Key and Value projection parameters across all heads while keeping an independent projection for each Query head. This significantly reduces the parameter count and computational cost, while retaining much of the parallelism advantage of multi-head attention.
import torch
import torch.nn as nn
import math


class MultiQueryAttention(nn.Module):
    def __init__(self, hidden_size, num_heads, dropout=0.0):
        """
        Multi-Query Attention implementation.

        Args:
            hidden_size (int): dimension of the input features, i.e. the last dim of the hidden state.
            num_heads (int): number of attention heads.
            dropout (float): dropout probability, 0.0 by default.
        """
        super(MultiQueryAttention, self).__init__()
        assert hidden_size % num_heads == 0, "hidden_size must be divisible by num_heads"

        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads  # dimension of each head

        # Linear projections for Q, K, V
        self.query = nn.Linear(hidden_size, hidden_size)     # independent Query per head
        self.key = nn.Linear(hidden_size, self.head_dim)     # single Key shared by all heads
        self.value = nn.Linear(hidden_size, self.head_dim)   # single Value shared by all heads

        self.dropout = nn.Dropout(dropout)
        self.out_projection = nn.Linear(hidden_size, hidden_size)

    def forward(self, query, key, value, mask=None):
        N = query.shape[0]  # batch size
        query_len = query.shape[1]
        key_len = key.shape[1]
        value_len = value.shape[1]

        # Project and split the queries into heads; K and V keep a single shared head
        query = self.query(query).view(N, query_len, self.num_heads, self.head_dim)
        key = self.key(key).view(N, key_len, 1, self.head_dim)
        value = self.value(value).view(N, value_len, 1, self.head_dim)

        # Rearrange for the dot product
        query = query.transpose(1, 2)   # shape: (N, heads, query_len, head_dim)
        key = key.transpose(1, 2)       # shape: (N, 1, key_len, head_dim), broadcast over heads
        value = value.transpose(1, 2)   # shape: (N, 1, value_len, head_dim), broadcast over heads

        # Scaled dot-product attention scores
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dim)
        # scores shape: (N, heads, query_len, key_len)

        # Apply the mask (if provided)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # Softmax over the key dimension
        attention_weights = torch.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        # attention_weights shape: (N, heads, query_len, key_len)

        # Weighted sum of the shared values
        context = torch.matmul(attention_weights, value)
        # context shape: (N, heads, query_len, head_dim)

        # Merge the heads back
        context = context.transpose(1, 2).contiguous().view(N, query_len, self.hidden_size)
        # context shape: (N, query_len, hidden_size)

        # Final output projection
        output = self.out_projection(context)
        # output shape: (N, query_len, hidden_size)
        return output


# Example usage:
if __name__ == "__main__":
    hidden_size = 512   # embedding dimension
    num_heads = 8       # number of heads
    batch_size = 64     # batch size
    seq_length = 10     # sequence length

    mqa = MultiQueryAttention(hidden_size, num_heads)
    query = key = value = torch.randn(batch_size, seq_length, hidden_size)
    print(mqa.forward(query, key, value, mask=None).shape)
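To make the sharing concrete, the short check below (a sketch assuming the MultiQueryAttention module defined above) shows that the shared key projection produces a single head_dim-sized vector per token, while the per-head query projection still produces num_heads * head_dim features. This is where MQA's parameter and key/value storage savings come from.

    mqa = MultiQueryAttention(hidden_size=512, num_heads=8)
    x = torch.randn(2, 10, 512)
    print(mqa.key(x).shape)    # torch.Size([2, 10, 64])  -> one shared Key per token
    print(mqa.query(x).shape)  # torch.Size([2, 10, 512]) -> 8 heads x 64 dims for the Queries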
Grouped Query Attention (GQA)
Grouped Query Attention (GQA) is a compromise between multi-head attention (MHA) and multi-query attention (MQA). Its core idea is to divide the query heads into groups, with the query heads in each group sharing one set of Keys and Values, thereby reducing parameters and computation while retaining the parallelism of multiple heads. GQA strikes a balance between parameter efficiency and model quality, which makes it well suited to the efficient deployment of large models.
import torch
import torch.nn as nn
import math


class GroupedQueryAttention(nn.Module):
    def __init__(self, hidden_size, num_heads, group_size=2, dropout=0.0):
        """
        Grouped Query Attention implementation.

        Args:
            hidden_size (int): dimension of the input features.
            num_heads (int): number of query heads.
            group_size (int): number of query heads in each group.
            dropout (float): dropout probability.
        """
        super(GroupedQueryAttention, self).__init__()
        assert hidden_size % num_heads == 0, "hidden_size must be divisible by num_heads"
        assert num_heads % group_size == 0, "num_heads must be divisible by group_size"

        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.group_size = group_size
        self.group_num = num_heads // group_size
        self.head_dim = hidden_size // num_heads  # dimension of each head

        # Linear projections for Q, K, V
        self.query = nn.Linear(hidden_size, hidden_size)                     # independent Query per head
        self.key = nn.Linear(hidden_size, self.group_num * self.head_dim)    # one Key per group
        self.value = nn.Linear(hidden_size, self.group_num * self.head_dim)  # one Value per group

        self.dropout = nn.Dropout(dropout)
        self.out_projection = nn.Linear(hidden_size, hidden_size)

    def forward(self, query, key, value, mask=None):
        N = query.shape[0]  # batch size
        query_len = query.shape[1]
        key_len = key.shape[1]
        value_len = value.shape[1]

        # Project and split into heads (queries) and groups (keys/values)
        query = self.query(query).view(N, query_len, self.num_heads, self.head_dim)
        key = self.key(key).view(N, key_len, self.group_num, self.head_dim)
        value = self.value(value).view(N, value_len, self.group_num, self.head_dim)

        # Rearrange for the dot product
        query = query.transpose(1, 2)   # shape: (N, heads, query_len, head_dim)
        key = key.transpose(1, 2)       # shape: (N, group_num, key_len, head_dim)
        value = value.transpose(1, 2)   # shape: (N, group_num, value_len, head_dim)

        # Expand each group's K/V to the query heads that share it
        key = key.repeat_interleave(self.group_size, dim=1)      # shape: (N, heads, key_len, head_dim)
        value = value.repeat_interleave(self.group_size, dim=1)  # shape: (N, heads, value_len, head_dim)

        # Scaled dot-product attention scores
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dim)
        # scores shape: (N, heads, query_len, key_len)

        # Apply the mask (if provided)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # Softmax over the key dimension
        attention_weights = torch.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        # attention_weights shape: (N, heads, query_len, key_len)

        # Weighted sum of the values
        context = torch.matmul(attention_weights, value)
        # context shape: (N, heads, query_len, head_dim)

        # Merge the heads back
        context = context.transpose(1, 2).contiguous().view(N, query_len, self.hidden_size)
        # context shape: (N, query_len, hidden_size)

        # Final output projection
        output = self.out_projection(context)
        # output shape: (N, query_len, hidden_size)
        return output


# Example usage:
if __name__ == "__main__":
    hidden_size = 512   # embedding dimension
    num_heads = 8       # number of heads
    group_size = 2      # number of query heads per group
    batch_size = 64     # batch size
    seq_length = 10     # sequence length

    gqa = GroupedQueryAttention(hidden_size, num_heads, group_size)
    query = key = value = torch.randn(batch_size, seq_length, hidden_size)
    print(gqa.forward(query, key, value, mask=None).shape)
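As a rough check of the parameter trade-off (a sketch assuming the GroupedQueryAttention module above), varying group_size interpolates between MHA (group_size=1, one K/V per head) and MQA (group_size=num_heads, one K/V shared by all heads). The snippet simply counts the K/V projection weights.

    for g in (1, 2, 4, 8):
        gqa = GroupedQueryAttention(hidden_size=512, num_heads=8, group_size=g)
        # Parameters of the grouped Key and Value projections only
        kv_params = sum(p.numel() for p in gqa.key.parameters()) \
                  + sum(p.numel() for p in gqa.value.parameters())
        print(f"group_size={g}: K/V projection parameters = {kv_params}")

The count shrinks as group_size grows, since fewer distinct K/V heads need to be projected and stored.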
Multi-Head Latent Attention (MLA)
Multi-Head Latent Attention (MLA) is an efficient attention mechanism that combines low-rank parameterization with rotary position embeddings (RoPE). Its core idea is to compress the query (Q), key (K), and value (V) dimensions through low-rank projections and to decouple content from positional information in the attention computation, reducing computational cost while preserving the ability to model long-range dependencies. MLA is particularly well suited to deploying large models, balancing efficiency and performance. The code below implements the RoPE component; a sketch of the low-rank compression follows after it.
import torch
import torch.nn as nn


class RotaryEmbedding(nn.Module):
    def __init__(self, hidden_size, num_heads, base=10000, max_len=512):
        """
        Rotary position embedding (RoPE) module.

        Args:
            hidden_size (int): model dimension.
            num_heads (int): number of attention heads.
            base (int): frequency base.
            max_len (int): maximum sequence length.
        """
        super().__init__()
        self.head_dim = hidden_size // num_heads
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.base = base
        self.max_len = max_len
        self.cos_pos_cache, self.sin_pos_cache = self._compute_pos_emb()

    def _compute_pos_emb(self):
        # Inverse frequencies, one per pair of dimensions within a head
        theta_i = 1. / (self.base ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
        # Position indices
        positions = torch.arange(self.max_len).float()
        # Angle table of shape (max_len, head_dim // 2)
        pos_emb = positions.unsqueeze(1) * theta_i.unsqueeze(0)
        # cos and sin lookup tables
        cos_pos = pos_emb.cos()
        sin_pos = pos_emb.sin()
        return cos_pos, sin_pos

    def forward(self, x, seq_len=None):
        # x shape: (batch, seq_len, hidden_size)
        N = x.size(0)
        if seq_len is None:
            seq_len = x.size(1)

        # Split the hidden dimension into heads so the rotation is applied per head
        x = x.view(N, seq_len, self.num_heads, self.head_dim)

        # Cached tables, reshaped to (1, seq_len, 1, head_dim // 2) for broadcasting
        cos_pos = self.cos_pos_cache[:seq_len].unsqueeze(0).unsqueeze(2)
        sin_pos = self.sin_pos_cache[:seq_len].unsqueeze(0).unsqueeze(2)

        # Split each head into two halves (rotate-half convention)
        x1, x2 = x.chunk(2, dim=-1)

        # Apply the rotation
        rotated_x1 = x1 * cos_pos - x2 * sin_pos
        rotated_x2 = x1 * sin_pos + x2 * cos_pos

        # Merge the halves and the heads back
        rotated_x = torch.cat([rotated_x1, rotated_x2], dim=-1)
        return rotated_x.reshape(N, seq_len, self.hidden_size)


# Example usage:
if __name__ == "__main__":
    hidden_size = 512   # model dimension
    num_heads = 8       # number of attention heads
    batch_size = 64     # batch size
    seq_length = 10     # sequence length

    rope = RotaryEmbedding(hidden_size, num_heads)
    x = torch.randn(batch_size, seq_length, hidden_size)
    print(rope(x).shape)
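The block above only covers the positional part of MLA. To illustrate the low-rank compression described at the start of this section, the following is a minimal sketch: the hidden states are down-projected into a small shared latent vector, which is then up-projected into per-head keys and values. Names such as SimplifiedMLA, kv_latent_dim, w_dkv, w_uk and w_uv are illustrative only; the decoupled RoPE branch and the low-rank query compression used by full MLA implementations are deliberately omitted, so this is not a faithful reproduction of any particular model.

    import torch
    import torch.nn as nn

    class SimplifiedMLA(nn.Module):
        """Low-rank KV compression only; the decoupled RoPE branch is omitted."""
        def __init__(self, hidden_size, num_heads, kv_latent_dim):
            super().__init__()
            assert hidden_size % num_heads == 0
            self.num_heads = num_heads
            self.head_dim = hidden_size // num_heads
            # Query projection (kept full-rank here for simplicity)
            self.w_q = nn.Linear(hidden_size, hidden_size)
            # Down-projection to the shared KV latent vector (the part that would be cached)
            self.w_dkv = nn.Linear(hidden_size, kv_latent_dim)
            # Up-projections from the latent vector to per-head keys and values
            self.w_uk = nn.Linear(kv_latent_dim, hidden_size)
            self.w_uv = nn.Linear(kv_latent_dim, hidden_size)
            self.w_o = nn.Linear(hidden_size, hidden_size)

        def forward(self, x, mask=None):
            N, seq_len, _ = x.shape
            q = self.w_q(x).view(N, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
            # Compress the hidden states into a small latent vector per token
            c_kv = self.w_dkv(x)  # (N, seq_len, kv_latent_dim)
            k = self.w_uk(c_kv).view(N, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
            v = self.w_uv(c_kv).view(N, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
            # Standard scaled dot-product attention over the reconstructed K/V
            scores = q @ k.transpose(-2, -1) / (self.head_dim ** 0.5)
            if mask is not None:
                scores = scores.masked_fill(mask == 0, float("-inf"))
            attn = torch.softmax(scores, dim=-1)
            out = (attn @ v).transpose(1, 2).contiguous().view(N, seq_len, -1)
            return self.w_o(out)

    if __name__ == "__main__":
        mla = SimplifiedMLA(hidden_size=512, num_heads=8, kv_latent_dim=128)
        x = torch.randn(2, 10, 512)
        print(mla(x).shape)  # torch.Size([2, 10, 512])

The point of the sketch is that only the kv_latent_dim-sized latent vector needs to be kept per token during decoding, rather than full per-head keys and values, which is where the memory savings of the low-rank formulation come from.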