掘金 人工智能 17小时前
手撕MHA、MLA、MQA、GQA
index_new5.html
../../../zaker_core/zaker_tpl_static/wap/tpl_guoji1.html

 

本文深入探讨了Transformer模型中几种高效的注意力机制,包括多头自注意力(MHA)、多查询注意力(MQA)、分组查询注意力(GQA)和多头潜在注意力(MLA)。通过对这些机制的原理、实现方式和应用场景的分析,揭示了它们在减少计算复杂度、提高模型效率和性能方面的优势。文章还提供了PyTorch代码示例,帮助读者更好地理解和应用这些技术。

🧠 多头自注意力(MHA)允许模型并行关注输入序列的不同部分,通过将输入映射到多个子空间,分别计算注意力权重并聚合结果,增强模型对复杂模式的捕捉能力。

💡 多查询注意力(MQA)是对MHA的改进,共享键和值的投影参数,仅对查询使用独立的头参数,显著减少模型参数量和计算复杂度,同时保留多头注意力的部分并行性优势。

🤝 分组查询注意力(GQA)是MHA和MQA的折中方案,将查询头划分为多个组,每组共享键和值,在参数效率与模型性能之间取得平衡,适用于大规模模型的高效部署。

✨ 多头潜在注意力(MLA)结合了低秩参数化与旋转位置编码(RoPE),通过低秩投影压缩Q、K、V的维度,解耦内容与位置信息,减少计算复杂度,同时保留长距离依赖建模能力,特别适用于大规模模型的部署。

多头自注意力机制(Multi-Head Attention, MHA)

通过并行计算多个注意力头,使模型能够同时关注输入序列中不同位置的特征。其核心思想是将输入映射到多个子空间,分别计算注意力权重并聚合结果,从而增强模型对复杂模式的捕捉能力。

import torchimport torch.nn as nnimport torch.nn.functional as Fclass MultiHeadAttention(nn.Module):    def __init__(self, embed_size, heads):        super(MultiHeadAttention, self).__init__()        self.embed_size = embed_size  # 嵌入维度大小        self.heads = heads              # 头的数量        self.head_dim = embed_size // heads  # 每个头的维度        # 确保嵌入维度可以被头的数量整除        assert (            self.head_dim * heads == embed_size        ), "Embedding size needs to be divisible by heads"        # 定义线性变换层,用于将输入转换为查询、键和值        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)        # 最终输出的全连接层        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)    def forward(self, values, keys, query, mask):        N = query.shape[0]  # 批量大小        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]        # 将嵌入向量分割成多个头        values = values.reshape(N, value_len, self.heads, self.head_dim)        keys = keys.reshape(N, key_len, self.heads, self.head_dim)        queries = query.reshape(N, query_len, self.heads, self.head_dim)        # 对每个头进行线性变换        values = self.values(values)        keys = self.keys(keys)        queries = self.queries(queries)        # 计算点积注意力分数        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])        # queries shape: (N, query_len, heads, head_dim)        # keys shape: (N, key_len, heads, head_dim)        # energy shape: (N, heads, query_len, key_len)        # 应用掩码(如果存在)        if mask is not None:            energy = energy.masked_fill(mask == 0, float("-1e20"))        # 对注意力分数应用softmax函数        attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)        # 计算加权求和后的输出        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(            N, query_len, self.heads * self.head_dim        )        # attention shape: (N, heads, query_len, key_len)        # values shape: (N, value_len, heads, heads_dim)        # out after matrix multiply: (N, query_len, heads, head_dim), then flatten last two dimensions        # 通过全连接层得到最终输出        out = self.fc_out(out)        return out# 示例用法:if __name__ == "__main__":    embed_size = 512  # 嵌入维度大小    heads = 8         # 头的数量    batch_size = 64   # 批量大小    seq_length = 10   # 序列长度    mha = MultiHeadAttention(embed_size, heads)    values = keys = query = torch.randn(batch_size, seq_length, embed_size)    print(mha.forward(values, keys, query, mask=None).shape)

多查询注意力机制(Multi-Query Attention,MQA)

Multi-Query Attention (MQA) 是对多头注意力(MHA)的高效改进版本,其核心思想是共享键(Key)和值(Value)的投影参数,仅对查询(Query)使用独立的头参数。这种方法显著减少了模型参数量和计算复杂度,同时保留了多头注意力的部分并行性优势。

import torchimport torch.nn as nnfrom thop import profileclass MultiQueryAttention(nn.Module):    def __init__(self, hidden_size, num_heads, dropout=0.0):        """        Multi-Query Attention 的实现。        Args:            hidden_size (int): 输入特征的维度,也即 hidden_state 的最后一维。            num_heads (int): 注意力头的数量。            dropout (float): dropout 的概率,默认为 0.0。        """        super(MultiQueryAttention, self).__init__()        assert hidden_size % num_heads == 0, "hidden_size 必须能被 num_heads 整除"                self.hidden_size = hidden_size        self.num_heads = num_heads        self.head_dim = hidden_size // num_heads  # 每个头的维度                # 定义线性变换层,用于生成 Q, K, V        self.query = nn.Linear(hidden_size, hidden_size)  # 每个头独立的 Query        self.key = nn.Linear(hidden_size, self.head_dim)   # 所有头共享的 Key        self.value = nn.Linear(hidden_size, self.head_dim) # 所有头共享的 Value                self.dropout = nn.Dropout(dropout)        self.out_projection = nn.Linear(hidden_size, hidden_size)    def forward(self, query, key, value, mask=None):        N = query.shape[0]  # 批量大小        query_len = query.shape[1]        key_len = key.shape[1]        value_len = value.shape[1]        # 将嵌入向量分割成多个头        query = self.query(query).view(N, query_len, self.num_heads, self.head_dim)        key = self.key(key).view(N, key_len, self.num_heads, self.head_dim)        value = self.value(value).view(N, value_len, self.num_heads, self.head_dim)        # 调整形状以便进行点积操作        query = query.transpose(1, 2)  # shape: (N, heads, query_len, head_dim)        key = key.transpose(1, 2)      # shape: (N, heads, key_len, head_dim)        value = value.transpose(1, 2)    # shape: (N, heads, value_len, head_dim)        # 计算点积注意力分数        scores = torch.matmul(query, key.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim))        # scores shape: (N, heads, query_len, key_len)        # 应用掩码(如果存在)        if mask is not None:            scores = scores.masked_fill(mask == 0, float('-inf'))        # 对注意力分数应用softmax函数        attention_weights = torch.softmax(scores, dim=-1)        # attention_weights shape: (N, heads, query_len, key_len)        # 计算加权求和后的输出        context = torch.matmul(attention_weights, value)        # context shape: (N, heads, query_len, head_dim)        # 调整形状以便通过全连接层        context = context.transpose(1, 2).contiguous().view(N, query_len, self.hidden_size)        # context shape: (N, query_len, hidden_size)        # 通过全连接层得到最终输出        output = self.out_projection(context)        # output shape: (N, query_len, hidden_size)        return output# 示例用法:if __name__ == "__main__":    hidden_size = 512  # 嵌入维度大小    num_heads = 8      # 头的数量    batch_size = 64    # 批量大小    seq_length = 10    # 序列长度    mqa = MultiQueryAttention(hidden_size, num_heads)    query = key = value = torch.randn(batch_size, seq_length, hidden_size)    print(mqa.forward(query, key, value, mask=None).shape)

分组查询注意力机制(Grouped Query Attention,GQA)

Grouped Query Attention (GQA) 是对多头注意力(MHA)和多查询注意力(MQA)的折中优化方案。其核心思想是将查询头(Query Heads)划分为多个组(Group),每组内的查询头共享一组键(Key)和值(Value),从而在保留多头并行性的同时减少参数量和计算复杂度。GQA 在参数效率与模型性能之间取得了平衡,适用于大规模模型的高效部署。

import torchimport torch.nn as nnclass GroupedQueryAttention(nn.Module):    def __init__(self, hidden_size, num_heads, group_size=2, dropout=0.0):        """        Grouped Query Attention 实现。        Args:            hidden_size (int): 输入特征的维度。            num_heads (int): 查询头的数量。            group_size (int): 每个组中包含的查询头数量。            dropout (float): dropout 的概率。        """        super(GroupedQueryAttention, self).__init__()        assert hidden_size % num_heads == 0, "hidden_size 必须能被 num_heads 整除"        assert num_heads % group_size == 0, "num_heads 必须能被 group_size 整除"                self.hidden_size = hidden_size        self.num_heads = num_heads        self.group_size = group_size        self.group_num = num_heads // group_size        self.head_dim = hidden_size // num_heads  # 每个头的维度                # 定义线性变换层,用于生成 Q, K, V        self.query = nn.Linear(hidden_size, hidden_size)  # 每个头独立的 Query        self.key = nn.Linear(hidden_size, self.group_num * self.head_dim)  # 分组共享的 Key        self.value = nn.Linear(hidden_size, self.group_num * self.head_dim)  # 分组共享的 Value                self.dropout = nn.Dropout(dropout)        self.out_projection = nn.Linear(hidden_size, hidden_size)    def forward(self, query, key, value, mask=None):        N = query.shape[0]  # 批量大小        query_len = query.shape[1]        key_len = key.shape[1]        value_len = value.shape[1]        # 将嵌入向量分割成多个头        query = self.query(query).view(N, query_len, self.num_heads, self.head_dim)        key = self.key(key).view(N, key_len, self.group_num, self.head_dim)        value = self.value(value).view(N, value_len, self.group_num, self.head_dim)        # 调整形状以便进行点积操作        query = query.transpose(1, 2)  # shape: (N, heads, query_len, head_dim)        key = key.transpose(1, 2)      # shape: (N, group_num, key_len, head_dim)        value = value.transpose(1, 2)    # shape: (N, group_num, value_len, head_dim)        # 计算点积注意力分数        scores = torch.matmul(query, key.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim))        # scores shape: (N, heads, query_len, key_len)        # 应用掩码(如果存在)        if mask is not None:            scores = scores.masked_fill(mask == 0, float('-inf'))        # 对注意力分数应用softmax函数        attention_weights = torch.softmax(scores, dim=-1)        # attention_weights shape: (N, heads, query_len, key_len)        # 计算加权求和后的输出        context = torch.matmul(attention_weights, value)        # context shape: (N, heads, query_len, head_dim)        # 合并组内的头        context = context.view(N, query_len, self.num_heads, self.head_dim)        context = context.permute(0, 2, 1, 3).contiguous().view(N, query_len, self.hidden_size)        # context shape: (N, query_len, hidden_size)        # 通过全连接层得到最终输出        output = self.out_projection(context)        # output shape: (N, query_len, hidden_size)        return output# 示例用法:if __name__ == "__main__":    hidden_size = 512  # 嵌入维度大小    num_heads = 8      # 头的数量    group_size = 2     # 每个组中包含的查询头数量    batch_size = 64    # 批量大小    seq_length = 10    # 序列长度    gqa = GroupedQueryAttention(hidden_size, num_heads, group_size)    query = key = value = torch.randn(batch_size, seq_length, hidden_size)    print(gqa.forward(query, key, value, mask=None).shape)

多头潜在注意力(Multi-Head Latent Attention, MLA)

Multi-Head Latent Attention (MLA) 是一种结合低秩参数化与旋转位置编码(RoPE)的高效注意力机制。其核心思想是通过低秩投影压缩查询(Q)、键(K)、值(V)的维度,并在注意力计算中解耦内容与位置信息,从而减少计算复杂度,同时保留长距离依赖建模能力。MLA 特别适用于大规模模型的部署,平衡了效率与性能。

import torchimport torch.nn as nnimport mathclass RotaryEmbedding(nn.Module):    def __init__(self, hidden_size, num_heads, base=10000, max_len=512):        """        RoPE位置编码模块        Args:            hidden_size (int): 模型维度            num_heads (int): 注意力头数量            base (int): 频率基值            max_len (int): 最大序列长度        """        super().__init__()        self.head_dim = hidden_size // num_heads        self.hidden_size = hidden_size        self.num_heads = num_heads        self.base = base        self.max_len = max_len        self.cos_pos_cache, self.sin_pos_cache = self._compute_pos_emb()    def _compute_pos_emb(self):        # 计算频率因子        theta_i = 1. / (self.base ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))        # 创建位置索引        positions = torch.arange(self.max_len)        # 计算位置编码        pos_emb = positions.unsqueeze(1) * theta_i.unsqueeze(0)        # 计算cos和sin的位置编码        cos_pos = pos_emb.sin().repeat_interleave(2, dim=-1)        sin_pos = pos_emb.cos().repeat_interleave(2, dim=-1)        return cos_pos, sin_pos    def forward(self, x, seq_len=None):        if seq_len is None:            seq_len = x.size(1)        # 获取缓存中的位置编码        cos_pos = self.cos_pos_cache[:seq_len].unsqueeze(1)        sin_pos = self.sin_pos_cache[:seq_len].unsqueeze(1)        # 分割x为奇数和偶数部分        x1, x2 = x.chunk(2, dim=-1)        # 应用旋转位置编码        rotated_x1 = x1 * cos_pos - x2 * sin_pos        rotated_x2 = x1 * sin_pos + x2 * cos_pos        # 合并结果        rotated_x = torch.cat([rotated_x1, rotated_x2], dim=-1)        return rotated_x# 示例用法:if __name__ == "__main__":    hidden_size = 512  # 模型维度    num_heads = 8      # 注意力头数量    batch_size = 64    # 批量大小    seq_length = 10    # 序列长度    rope = RotaryEmbedding(hidden_size, num_heads)    x = torch.randn(batch_size, seq_length, hidden_size)    print(rope(x).shape)

Fish AI Reader

Fish AI Reader

AI辅助创作,多种专业模板,深度分析,高质量内容生成。从观点提取到深度思考,FishAI为您提供全方位的创作支持。新版本引入自定义参数,让您的创作更加个性化和精准。

FishAI

FishAI

鱼阅,AI 时代的下一个智能信息助手,助你摆脱信息焦虑

联系邮箱 441953276@qq.com

相关标签

多头注意力 MHA MQA GQA MLA
相关文章