掘金 人工智能 11小时前
一文搞懂FlashAttention怎么提升速度的?
index_new5.html
../../../zaker_core/zaker_tpl_static/wap/tpl_guoji1.html

 

本文深入探讨了FlashAttention,一种用于加速Transformer模型中Attention机制的创新方法。通过IO感知的优化策略,FlashAttention在不损失精度的前提下,显著减少了显存占用并提升了计算速度,解决了传统Attention在处理长序列时面临的性能瓶颈。文章详细介绍了FlashAttention的核心原理,包括Tile-Based计算、Recomputation策略以及Block Sparse FlashAttention,并分析了其在实际应用中的优势与局限性。

💡 FlashAttention的核心在于使用Tile-Based计算,将Q、K、V分块处理,利用GPU的SRAM(Shared Memory)完成QKᵀ和softmax的计算,避免中间结果写入HBM(高带宽显存),从而减少显存占用。

🔑 为了节省存储中间的softmax权重,FlashAttention采用了Recomputation策略,在前向传播时,不保留中间矩阵,在反向传播阶段重新计算,以“以算代存”的方式,显著降低显存开销,尤其是在长序列上效果明显。

⚙️ Block Sparse FlashAttention在FlashAttention的基础上,结合了IO-aware高效计算方式和block-sparse attention mask的稀疏结构,通过只计算被稀疏掩码指定的QK块对,进一步提升了处理超长序列的性能和可扩展性。

⚠️ FlashAttention在长序列任务中可将显存减少2-4倍,速度提升达2-4倍,但仍存在线程并行效率不高、不支持MQA/GQA等高效注意力结构、实现依赖Triton编译器以及反向传播内核优化空间等不足或限制。

本篇文章是Transformer系列的第七篇。

Transformer系列文章:

一览Transformer整体架构
Transformer——Attention怎么实现集中注意力
Transformer——FeedForward模块在干什么?
从0开始实现Transformer
什么是KV-Cache
Transformer注意力机制——MHA&MQA&GQA

所有相关源码示例、流程图、模型配置与知识库构建技巧,我也将持续更新在Github:LLMHub,欢迎关注收藏!

希望大家带着下面的问题来学习,我会在文末给出答案。

    传统 Attention 的主要性能瓶颈在哪里?为什么需要 FlashAttention?FlashAttention 是如何利用 shared memory 降低显存占用并提高速度的?FlashAttention 在实际应用中还有哪些不足或限制?

一、引言

Transformer 模型自诞生以来,已成为自然语言处理、计算机视觉、语音等领域的核心架构。而 Attention 机制作为 Transformer 的核心计算模块,其计算复杂度和显存占用在处理长序列时常常成为性能瓶颈。传统 Attention 的时间和空间复杂度为 O(n^2),这在大规模模型或长文本输入中表现为效率低下、显存不足。

FlashAttention 是由 Stanford Hazy Research 团队提出的一种高效实现方式,专为 GPU 设计,通过 I/O 感知(IO-aware)的优化策略,在不损失精度的前提下显著加速 Attention 计算并降低显存占用。

二、Attention回顾

在标准 Transformer 中,Attention 计算如下:

其中:

计算流程如下:

    计算 QK^T:需存储一个n*n 矩阵;计算 softmax;乘以 V 得到最终输出。

瓶颈分析:

三、FlashAttention 的核心思想(核心原理)

FlashAttention 的目标是:降低显存占用,同时提升速度,具体做法包括:

1. Tile-Based 计算

将 Q, K, V 分块为小块(tile),每次仅处理一小块:

上面图中左半部分是计算机的内存分布, HBM 是 “High Bandwidth Memory” 的缩写,也就是高带宽显存,是一种专为高性能计算和显存密集型任务(如 GPU、AI 加速、图形渲染等)设计的下一代显存技术。 SRAM是一种静态随机访问存储器,用于高速缓存等内部存储器,具有更快的访问速度和更低的延迟,但成本更高且占用更多芯片空间。

标准Attention的计算算法如下:

可以看到,标准 Attention 实现大量中间结果需频繁访问 HBM,而 HBM 的访问速度远远低于 GPU 的SRAM。因此 FlashAttention 通过“tile 计算+显存访问优化”方案,减少了对 HBM 的依赖,提高了整体执行效率

softmax计算公式如下:

为了数值稳定性,FlashAttention采用Safe Softmax,对于向量x

同理,对于向量x = [ x1,x2],softmax可以分解计算:

这就说明即使Q,K,V被分成块也是可以计算softmax的。

2. Recomputation Strategy

为了节省存储中间的 softmax 权重,FlashAttention 在需要时重新计算部分内容,避免保存完整矩阵。

标准Attention的反向传播算法如下,其中P代表Softmax(QKᵀ / √dₖ),也就是注意力权重矩阵。

结合着Attention的计算公式更好理解

在标准 Attention 实现中,为了完成前向传播和反向传播,我们通常需要保存如下中间结果:

这些矩阵很大,尤其是在处理长序列时,显存消耗会非常高。

FlashAttention 为了降低显存占用,采取了一种策略:

在前向传播时 不保留中间矩阵,而是到了反向传播阶段 再把它们重新计算出来

以 softmax 的 attention score 为例:

QKᵀ → softmax → 缓存在显存中 → 用于乘V和反向传播
QKᵀ → softmax → 直接用于乘V,不缓存...后面反向传播需要用到 softmax → 再重新计算一次 QKᵀ 和 softmax

这就节省了存 softmax 的显存开销,尤其在长序列上非常可观。

FlashAttention的前向传播算法如下:

FlashAttention的反向传播的过程如下:

可以看到其中没有存储,反向传播的过程中需要的数据都是重新计算的,这种“以算代存”的方式是一种典型的时间换空间(compute vs. memory)策略。虽然多计算一次会略微增加一点时间,但显存节省得非常明显,反而提升了整体性能,因为:

3.Block Sparse FlashAttention

传统 Attention 是 全连接的:每个 token 都和所有其他 token 交互,计算量为 O(n^2)。

Sparse Attention 只计算部分 token 对的关系,常见稀疏模式包括:

在FlashAttention 的基础上,为了进一步提升处理超长序列的性能和可扩展性,Block Sparse FlashAttention 结合了 FlashAttention 的 IO-aware 高效计算方式和 block-sparse attention mask 的稀疏结构,从而实现 更少计算 + 更少显存占用 的 attention 操作。

Block Sparse FlashAttention 的关键是在 FlashAttention 高效计算的基础上,只计算被稀疏掩码指定的 QK 块对,算法如下:

    输入:Q、K、V 被划分为若干 block;依据稀疏掩码(mask)决定哪些 Q-block 要与哪些 K-block 交互;对每个有效块对,执行 FlashAttention 核心流程(QKᵀ → softmax → ×V);将结果整合,拼接为完整输出。

四、FlashAttention vs 标准Attention

项目原始 AttentionFlashAttention
时间复杂度O(n^2 d)O(n^2 d),但更快
显存消耗高(存储中间矩阵)低(tile重计算)
速度表现慢(受限于显存读写)快(高效访存)
精度控制float32 为主支持 fp16 / bf16

在长序列任务中,FlashAttention 可将显存减少 2-4 倍,速度提升达 2-4 倍。

五、从0手撸FlashAttention

for i in range(0, N, block_size):    q_block = q[:, i:i+block_size]  # [B, Bq, D]    max_score = None    row_sum_exp = None    acc = torch.zeros_like(q_block)    for j in range(0, N, block_size):        k_block = k[:, j:j+block_size]  # [B, Bk, D]        v_block = v[:, j:j+block_size]  # [B, Bk, D]        # 1. Attention logits        scores = torch.bmm(q_block, k_block.transpose(1, 2)) * scale  # [B, Bq, Bk]        # 2. Numerical stability        block_max = scores.max(dim=-1, keepdim=True).values  # [B, Bq, 1]        scores = scores - block_max        exp_scores = scores.exp()  # [B, Bq, Bk]        # 3. Dropout (可选)        if dropout_p > 0.0:            exp_scores = F.dropout(exp_scores, p=dropout_p, training=True)        # 4. Weighted sum        acc += torch.bmm(exp_scores, v_block)  # [B, Bq, D]        # 5. Softmax normalization (log-sum-exp trick)        block_sum = exp_scores.sum(dim=-1, keepdim=True)  # [B, Bq, 1]        if row_sum_exp is None:            row_sum_exp = block_sum            max_score = block_max        else:            row_sum_exp += block_sum            max_score = torch.max(max_score, block_max)    # Normalize accumulated result    output[:, i:i+block_size] = acc / (row_sum_exp + 1e-6)return output

要注意的是 上面的PyTorch 实现并没有用到 Shared Memory,它只是演示了 FlashAttention 的思想流程。
真正利用了 SRAM 的,是 FlashAttention 的 CUDA kernel 或 Triton kernel 实现

如果想要测试效率,可以直接调用torch封装好的flashattention

from flash_attn.modules.mha import FlashMHAimport torchx = torch.randn(8, 512, 512, device='cuda')  # batch, seq_len, dimmha = FlashMHA(embed_dim=512, num_heads=8, device='cuda')output = mha(x)print(output.shape)  # [8, 512, 512]

六、总结

FlashAttention 提供了一种高效、低显存的 Attention 实现方式,极大地缓解了 Transformer 模型在长序列处理中的性能瓶颈。在当前大模型时代,FlashAttention 成为高效训练与部署的关键组件之一。

最后,我们回答一下文章开头提出的问题。

    传统 Attention 的主要性能瓶颈在哪里?为什么需要 FlashAttention?

标准的 Attention 实现存在两个严重问题:

FlashAttention 被提出,目标就是通过“在 shared memory 中块级 tile 化 attention 计算”,避免 score 的 materialization 和重复访存,从而提升效率、减少内存压力。

    FlashAttention 是如何利用 shared memory 降低显存占用并提高速度的?

FlashAttention 的关键设计是:

    FlashAttention 在实际应用中还有哪些不足或限制?

    尽管 FlashAttention 在性能和显存方面带来显著改善,但也存在一些实际问题:

关于深度学习和大模型相关的知识和前沿技术更新,请关注公众号算法coting!

以上内容部分参考了

FlashAttention:Fast and Memory-Efficient Exact Attention with IO-Awareness

Flash Attention原理详解(含代码讲解)

非常感谢,如有侵权请联系删除!

Fish AI Reader

Fish AI Reader

AI辅助创作,多种专业模板,深度分析,高质量内容生成。从观点提取到深度思考,FishAI为您提供全方位的创作支持。新版本引入自定义参数,让您的创作更加个性化和精准。

FishAI

FishAI

鱼阅,AI 时代的下一个智能信息助手,助你摆脱信息焦虑

联系邮箱 441953276@qq.com

相关标签

FlashAttention Transformer Attention机制 深度学习 模型优化
相关文章