智源社区 · November 21, 2024
[Triton Tutorial] Fused Attention

Triton is a language and compiler for parallel programming. It provides a Python-based programming environment for efficiently writing custom DNN compute kernels that run at maximum throughput on modern GPU hardware. The article also describes the individual functions and their computational logic.

🎯 Triton is a parallel-programming language and compiler. It is based on Python and used to write DNN compute kernels.

📈 The article walks through several functions, such as _attn_fwd_inner, covering the range of values handled by each stage and the running update logic (the recurrence is sketched after this list).

⚙️ To keep the tutorial fast, auto-tuning is not run on every execution; the relevant code is kept so the kernels can easily be re-tuned.

🧮 It also covers the backward-pass computations, such as dV, dP, dS, and dQ.
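
For orientation, the streaming ("online") softmax update that _attn_fwd_inner applies to each key/value block (K_j, V_j) can be written as the following recurrence; the kernel computes it in base 2 via tl.math.exp2, but the algebra is identical with the natural exponential. Here m is the running row-wise maximum, ℓ the running normalizer, acc the unnormalized output accumulator, and s the softmax scale:

$$
\begin{aligned}
m^{\text{new}} &= \max\bigl(m,\ \operatorname{rowmax}(s\,qK_j^{\top})\bigr), &
\alpha &= \exp\bigl(m - m^{\text{new}}\bigr),\\
p &= \exp\bigl(s\,qK_j^{\top} - m^{\text{new}}\bigr), &
\ell &\leftarrow \alpha\,\ell + \operatorname{rowsum}(p),\\
\mathrm{acc} &\leftarrow \alpha\,\mathrm{acc} + p\,V_j, &
m &\leftarrow m^{\text{new}}.
\end{aligned}
$$

Once every key/value block has been processed, the attention output for the block of queries is acc / ℓ.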

Triton is a language and compiler for parallel programming. It aims to provide a Python-based programming environment for efficiently writing custom DNN compute kernels that can run at maximum throughput on modern GPU hardware.
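
Before diving into the attention kernels, the general shape of a Triton program is worth seeing once. The sketch below is illustrative and not part of the tutorial (the kernel name add_kernel and the 1024-element block size are ours): each program instance derives its offsets from its program id, loads a masked block, and stores the result.

import torch

import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Each program instance handles one BLOCK_SIZE-long slice of the inputs.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements  # guard the ragged last block
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n = out.numel()
    grid = (triton.cdiv(n, 1024),)  # one program per 1024-element block
    add_kernel[grid](x, y, out, n, BLOCK_SIZE=1024)
    return out

The fused-attention kernels from the tutorial are listed below.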

import pytest
import torch

import triton
import triton.language as tl


def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"


@triton.jit
def _attn_fwd_inner(acc, l_i, m_i, q,  #
                    K_block_ptr, V_block_ptr,  #
                    start_m, qk_scale,  #
                    BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr,  #
                    STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr,  #
                    N_CTX: tl.constexpr, fp8_v: tl.constexpr):
    # range of values handled by this stage
    if STAGE == 1:
        lo, hi = 0, start_m * BLOCK_M
    elif STAGE == 2:
        lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
        lo = tl.multiple_of(lo, BLOCK_M)
    # causal = False
    else:
        lo, hi = 0, N_CTX
    K_block_ptr = tl.advance(K_block_ptr, (0, lo))
    V_block_ptr = tl.advance(V_block_ptr, (lo, 0))
    # loop over k, v and update accumulator
    for start_n in range(lo, hi, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = tl.load(K_block_ptr)
        qk = tl.dot(q, k)
        if STAGE == 2:
            mask = offs_m[:, None] >= (start_n + offs_n[None, :])
            qk = qk * qk_scale + tl.where(mask, 0, -1.0e6)
            m_ij = tl.maximum(m_i, tl.max(qk, 1))
            qk -= m_ij[:, None]
        else:
            m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
            qk = qk * qk_scale - m_ij[:, None]
        p = tl.math.exp2(qk)
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i
        alpha = tl.math.exp2(m_i - m_ij)
        l_i = l_i * alpha + l_ij
        # -- update output accumulator --
        acc = acc * alpha[:, None]
        # update acc
        v = tl.load(V_block_ptr)
        if fp8_v:
            p = p.to(tl.float8e5)
        else:
            p = p.to(tl.float16)
        acc = tl.dot(p, v, acc)
        # update m_i and l_i
        m_i = m_ij
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
    return acc, l_i, m_i


# We don't run auto-tuning every time to keep the tutorial fast. Keeping
# the code below and commenting out the equivalent parameters is convenient for
# re-tuning.
configs = [
    triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w)
    for BM in [64, 128]
    for BN in [32, 64]
    for s in ([1] if is_hip() else [3, 4, 7])
    for w in [4, 8]
]


def keep(conf):
    BLOCK_M = conf.kwargs["BLOCK_M"]
    BLOCK_N = conf.kwargs["BLOCK_N"]
    # The body of this filter is cut off in the article text; it is completed
    # here as in the upstream Triton tutorial: skip small tiles that ask for 8 warps.
    if BLOCK_M * BLOCK_N < 128 * 128 and conf.num_warps == 8:
        return False
    return True


# [The @triton.autotune-decorated forward kernel _attn_fwd and the backward
#  preprocessing kernel _attn_bwd_preprocess belong here but are not shown.]


# The main inner-loop logic for computing dK and dV.
# (The signature and prologue below are reconstructed from the call sites in
#  _attn_bwd and by symmetry with _attn_bwd_dq; only the loop body from the
#  autoregressive mask onward survives in the article text.)
@triton.jit
def _attn_bwd_dkdv(dk, dv,  #
                   Q, k, v, sm_scale,  #
                   DO,  #
                   M, D,  #
                   # shared by Q/K/V/DO.
                   stride_tok, stride_d,  #
                   H, N_CTX,  #
                   BLOCK_M1: tl.constexpr,  #
                   BLOCK_N1: tl.constexpr,  #
                   HEAD_DIM: tl.constexpr,  #
                   # Filled in by the wrapper.
                   start_n, start_m, num_steps,  #
                   MASK: tl.constexpr):
    offs_m = start_m + tl.arange(0, BLOCK_M1)
    offs_n = start_n + tl.arange(0, BLOCK_N1)
    offs_k = tl.arange(0, HEAD_DIM)
    qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d
    do_ptrs = DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
    # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
    curr_m = start_m
    step_m = BLOCK_M1
    for blk_idx in range(num_steps):
        qT = tl.load(qT_ptrs)
        offs_m = curr_m + tl.arange(0, BLOCK_M1)
        m = tl.load(M + offs_m)
        qkT = tl.dot(k, qT)
        pT = tl.math.exp2(qkT - m[None, :])
        # Autoregressive masking.
        if MASK:
            mask = (offs_m[None, :] >= offs_n[:, None])
            pT = tl.where(mask, pT, 0.0)
        do = tl.load(do_ptrs)
        # Compute dV.
        ppT = pT
        ppT = ppT.to(tl.float16)
        dv += tl.dot(ppT, do)
        # D (= delta) is pre-divided by ds_scale.
        Di = tl.load(D + offs_m)
        # Compute dP and dS.
        dpT = tl.dot(v, tl.trans(do)).to(tl.float32)
        dsT = pT * (dpT - Di[None, :])
        dsT = dsT.to(tl.float16)
        dk += tl.dot(dsT, tl.trans(qT))
        # Increment pointers.
        curr_m += step_m
        qT_ptrs += step_m * stride_tok
        do_ptrs += step_m * stride_tok
    return dk, dv


# the main inner-loop logic for computing dQ
@triton.jit
def _attn_bwd_dq(dq, q, K, V,  #
                 do, m, D,
                 # shared by Q/K/V/DO.
                 stride_tok, stride_d,  #
                 H, N_CTX,  #
                 BLOCK_M2: tl.constexpr,  #
                 BLOCK_N2: tl.constexpr,  #
                 HEAD_DIM: tl.constexpr,
                 # Filled in by the wrapper.
                 start_m, start_n, num_steps,  #
                 MASK: tl.constexpr):
    offs_m = start_m + tl.arange(0, BLOCK_M2)
    offs_n = start_n + tl.arange(0, BLOCK_N2)
    offs_k = tl.arange(0, HEAD_DIM)
    kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
    vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
    # D (= delta) is pre-divided by ds_scale.
    Di = tl.load(D + offs_m)
    # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
    curr_n = start_n
    step_n = BLOCK_N2
    for blk_idx in range(num_steps):
        kT = tl.load(kT_ptrs)
        vT = tl.load(vT_ptrs)
        qk = tl.dot(q, kT)
        p = tl.math.exp2(qk - m)
        # Autoregressive masking.
        if MASK:
            offs_n = curr_n + tl.arange(0, BLOCK_N2)
            mask = (offs_m[:, None] >= offs_n[None, :])
            p = tl.where(mask, p, 0.0)
        # Compute dP and dS.
        dp = tl.dot(do, vT).to(tl.float32)
        ds = p * (dp - Di[:, None])
        ds = ds.to(tl.float16)
        # Compute dQ.
        # NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
        dq += tl.dot(ds, tl.trans(kT))
        # Increment pointers.
        curr_n += step_n
        kT_ptrs += step_n * stride_tok
        vT_ptrs += step_n * stride_tok
    return dq


@triton.jit
def _attn_bwd(Q, K, V, sm_scale,  #
              DO,  #
              DQ, DK, DV,  #
              M, D,
              # shared by Q/K/V/DO.
              stride_z, stride_h, stride_tok, stride_d,  #
              H, N_CTX,  #
              BLOCK_M1: tl.constexpr,  #
              BLOCK_N1: tl.constexpr,  #
              BLOCK_M2: tl.constexpr,  #
              BLOCK_N2: tl.constexpr,  #
              BLK_SLICE_FACTOR: tl.constexpr,  #
              HEAD_DIM: tl.constexpr):
    LN2: tl.constexpr = 0.6931471824645996  # = ln(2)

    bhid = tl.program_id(2)
    off_chz = (bhid * N_CTX).to(tl.int64)
    adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64)
    pid = tl.program_id(0)

    # offset pointers for batch/head
    Q += adj
    K += adj
    V += adj
    DO += adj
    DQ += adj
    DK += adj
    DV += adj
    M += off_chz
    D += off_chz

    # load scales
    offs_k = tl.arange(0, HEAD_DIM)

    start_n = pid * BLOCK_N1
    start_m = start_n

    MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
    offs_n = start_n + tl.arange(0, BLOCK_N1)

    dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
    dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)

    # load K and V: they stay in SRAM throughout the inner loop.
    k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
    v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)

    num_steps = BLOCK_N1 // MASK_BLOCK_M1

    dk, dv = _attn_bwd_dkdv(dk, dv,  #
                            Q, k, v, sm_scale,  #
                            DO,  #
                            M, D,  #
                            stride_tok, stride_d,  #
                            H, N_CTX,  #
                            MASK_BLOCK_M1, BLOCK_N1, HEAD_DIM,  #
                            start_n, start_m, num_steps,  #
                            MASK=True  #
                            )

    start_m += num_steps * MASK_BLOCK_M1
    num_steps = (N_CTX - start_m) // BLOCK_M1

    # Compute dK and dV for non-masked blocks.
    dk, dv = _attn_bwd_dkdv(  #
        dk, dv,  #
        Q, k, v, sm_scale,  #
        DO,  #
        M, D,  #
        stride_tok, stride_d,  #
        H, N_CTX,  #
        BLOCK_M1, BLOCK_N1, HEAD_DIM,  #
        start_n, start_m, num_steps,  #
        MASK=False  #
    )

    dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
    tl.store(dv_ptrs, dv)

    # Write back dK.
    dk *= sm_scale
    dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
    tl.store(dk_ptrs, dk)

    # THIS BLOCK DOES DQ:
    start_m = pid * BLOCK_M2
    end_n = start_m + BLOCK_M2

    MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
    offs_m = start_m + tl.arange(0, BLOCK_M2)

    q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
    dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32)
    do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)

    m = tl.load(M + offs_m)
    m = m[:, None]

    # Compute dQ for masked (diagonal) blocks.
    # NOTE: This code scans each row of QK^T backward (from right to left,
    # but inside each call to _attn_bwd_dq, from left to right), but that's
    # not due to anything important.  I just wanted to reuse the loop
    # structure for dK & dV above as much as possible.
    num_steps = BLOCK_M2 // MASK_BLOCK_N2
    dq = _attn_bwd_dq(dq, q, K, V,  #
                      do, m, D,  #
                      stride_tok, stride_d,  #
                      H, N_CTX,  #
                      BLOCK_M2, MASK_BLOCK_N2, HEAD_DIM,  #
                      start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps,  #
                      MASK=True  #
                      )
    end_n -= num_steps * MASK_BLOCK_N2
    # stage 2
    num_steps = end_n // BLOCK_N2
    dq = _attn_bwd_dq(dq, q, K, V,  #
                      do, m, D,  #
                      stride_tok, stride_d,  #
                      H, N_CTX,  #
                      BLOCK_M2, BLOCK_N2, HEAD_DIM,  #
                      start_m, end_n - num_steps * BLOCK_N2, num_steps,  #
                      MASK=False  #
                      )
    # Write back dQ.
    dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
    dq *= LN2
    tl.store(dq_ptrs, dq)


class _attention(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k, v, causal, sm_scale):
        # shape constraints
        HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1]
        # when v is in float8_e5m2 it is transposed.
        HEAD_DIM_V = v.shape[-1]
        assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
        assert HEAD_DIM_K in {16, 32, 64, 128, 256}
        o = torch.empty_like(q)
        stage = 3 if causal else 1
        extra_kern_args = {}
        # Tuning for AMD target (the condition is cut off in the article; it is
        # completed here as in the upstream Triton tutorial).
        if is_hip():
            waves_per_eu = 3 if HEAD_DIM_K <= 64 else 2
            extra_kern_args = {"waves_per_eu": waves_per_eu, "allow_flush_denorm": True}

        # [The forward kernel launch, the backward() method, the public
        #  `attention = _attention.apply` entry point, and the accompanying
        #  test/benchmark code that complete the tutorial are not shown here.]
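
One detail worth calling out in the kernels above is the base-2 arithmetic: tl.math.exp2 is used throughout, which is equivalent to the natural exponential because exp(x) = exp2(x · log2 e); this is also why _attn_bwd multiplies dq by LN2 = ln 2 at the end, undoing the pre-scaling mentioned in its comments. As a sanity check, the fused kernels should match a plain PyTorch implementation of causal attention. The sketch below is ours, not part of the tutorial; the full tutorial exposes the kernels through an `attention = _attention.apply` wrapper (not shown above) and its test compares that wrapper against a reference of this kind.

import math

import torch


def reference_attention(q, k, v, causal=True, sm_scale=None):
    # Plain-PyTorch reference: softmax(q @ k^T * sm_scale) @ v with an optional
    # causal mask. Shapes are (batch, heads, seq_len, head_dim).
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(q.shape[-1])
    scores = torch.matmul(q, k.transpose(-2, -1)).float() * sm_scale
    if causal:
        n = q.shape[-2]
        causal_mask = torch.tril(torch.ones(n, n, dtype=torch.bool, device=q.device))
        scores = scores.masked_fill(~causal_mask, float("-inf"))
    # exp(x) == exp2(x * log2(e)), so the kernels' exp2-based softmax yields the
    # same probabilities as this standard softmax.
    p = torch.softmax(scores, dim=-1).to(q.dtype)
    return torch.matmul(p, v)


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32
    torch.manual_seed(0)
    q, k, v = (torch.randn(1, 2, 128, 64, device=device, dtype=dtype) for _ in range(3))
    out = reference_attention(q, k, v, causal=True, sm_scale=0.5)
    print(out.shape)  # torch.Size([1, 2, 128, 64])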

