Triton is a language and compiler for parallel programming. It aims to provide a Python-based programming environment for productively writing custom DNN compute kernels that can run at maximal throughput on modern GPU hardware.
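Before diving into the fused attention kernels, here is a minimal, self-contained sketch of what a Triton kernel looks like: an element-wise vector addition. The kernel name, block size, and launch grid are illustrative only and are not part of this tutorial.

import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Each program instance handles one BLOCK_SIZE-wide slice of the vectors.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Guard the tail when n_elements is not a multiple of BLOCK_SIZE.
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


x = torch.randn(4096, device="cuda")
y = torch.randn(4096, device="cuda")
out = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 1024), )
add_kernel[grid](x, y, out, x.numel(), BLOCK_SIZE=1024)

The remainder of this tutorial implements the fused attention forward and backward kernels.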
import pytest
import torch

import triton
import triton.language as tl


def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"


@triton.jit
def _attn_fwd_inner(acc, l_i, m_i, q,  #
                    K_block_ptr, V_block_ptr,  #
                    start_m, qk_scale,  #
                    BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr,  #
                    STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr,  #
                    N_CTX: tl.constexpr, fp8_v: tl.constexpr):
    # range of values handled by this stage
    if STAGE == 1:
        lo, hi = 0, start_m * BLOCK_M
    elif STAGE == 2:
        lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
        lo = tl.multiple_of(lo, BLOCK_M)
    # causal = False
    else:
        lo, hi = 0, N_CTX
    K_block_ptr = tl.advance(K_block_ptr, (0, lo))
    V_block_ptr = tl.advance(V_block_ptr, (lo, 0))
    # loop over k, v and update accumulator
    for start_n in range(lo, hi, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = tl.load(K_block_ptr)
        qk = tl.dot(q, k)
        if STAGE == 2:
            mask = offs_m[:, None] >= (start_n + offs_n[None, :])
            qk = qk * qk_scale + tl.where(mask, 0, -1.0e6)
            m_ij = tl.maximum(m_i, tl.max(qk, 1))
            qk -= m_ij[:, None]
        else:
            m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
            qk = qk * qk_scale - m_ij[:, None]
        p = tl.math.exp2(qk)
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i
        alpha = tl.math.exp2(m_i - m_ij)
        l_i = l_i * alpha + l_ij
        # -- update output accumulator --
        acc = acc * alpha[:, None]
        # update acc
        v = tl.load(V_block_ptr)
        if fp8_v:
            p = p.to(tl.float8e5)
        else:
            p = p.to(tl.float16)
        acc = tl.dot(p, v, acc)
        # update m_i and l_i
        m_i = m_ij
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
    return acc, l_i, m_i


# We don't run auto-tuning every time to keep the tutorial fast. Keeping
# the code below and commenting out the equivalent parameters is convenient for
# re-tuning.
configs = [
    triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w) \
    for BM in [64, 128]\
    for BN in [32, 64]\
    for s in ([1] if is_hip() else [3, 4, 7])\
    for w in [4, 8]\
]


def keep(conf):
    BLOCK_M = conf.kwargs["BLOCK_M"]
    BLOCK_N = conf.kwargs["BLOCK_N"]
    if BLOCK_M * BLOCK_N < 128 * 128 and conf.num_warps == 8:
        return False
    return True


# (The autotuned forward attention kernel and the backward preprocessing kernel
# that computes the per-row delta are omitted here.)


# The main inner-loop logic for computing dK and dV.
@triton.jit
def _attn_bwd_dkdv(dk, dv,  #
                   Q, k, v, sm_scale,  #
                   DO,  #
                   M, D,  #
                   # shared by Q/K/V/DO.
                   stride_tok, stride_d,  #
                   H, N_CTX, BLOCK_M1: tl.constexpr,  #
                   BLOCK_N1: tl.constexpr,  #
                   HEAD_DIM: tl.constexpr,  #
                   # Filled in by the wrapper.
                   start_n, start_m, num_steps,  #
                   MASK: tl.constexpr):
    offs_m = start_m + tl.arange(0, BLOCK_M1)
    offs_n = start_n + tl.arange(0, BLOCK_N1)
    offs_k = tl.arange(0, HEAD_DIM)
    qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d
    do_ptrs = DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
    # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
    curr_m = start_m
    step_m = BLOCK_M1
    for blk_idx in range(num_steps):
        qT = tl.load(qT_ptrs)
        # Load m before computing qk to reduce pipeline stall.
        offs_m = curr_m + tl.arange(0, BLOCK_M1)
        m = tl.load(M + offs_m)
        qkT = tl.dot(k, qT)
        pT = tl.math.exp2(qkT - m[None, :])
        # Autoregressive masking.
        if MASK:
            mask = (offs_m[None, :] >= offs_n[:, None])
            pT = tl.where(mask, pT, 0.0)
        do = tl.load(do_ptrs)
        # Compute dV.
        ppT = pT
        ppT = ppT.to(tl.float16)
        dv += tl.dot(ppT, do)
        # D (= delta) is pre-divided by ds_scale.
        Di = tl.load(D + offs_m)
        # Compute dP and dS.
        dpT = tl.dot(v, tl.trans(do)).to(tl.float32)
        dsT = pT * (dpT - Di[None, :])
        dsT = dsT.to(tl.float16)
        dk += tl.dot(dsT, tl.trans(qT))
        # Increment pointers.
        curr_m += step_m
        qT_ptrs += step_m * stride_tok
        do_ptrs += step_m * stride_tok
    return dk, dv


# the main inner-loop logic for computing dQ
@triton.jit
def _attn_bwd_dq(dq, q, K, V,  #
                 do, m, D,
                 # shared by Q/K/V/DO.
                 stride_tok, stride_d,  #
                 H, N_CTX,  #
                 BLOCK_M2: tl.constexpr,  #
                 BLOCK_N2: tl.constexpr,  #
                 HEAD_DIM: tl.constexpr,
                 # Filled in by the wrapper.
                 start_m, start_n, num_steps,  #
                 MASK: tl.constexpr):
    offs_m = start_m + tl.arange(0, BLOCK_M2)
    offs_n = start_n + tl.arange(0, BLOCK_N2)
    offs_k = tl.arange(0, HEAD_DIM)
    kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
    vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
    # D (= delta) is pre-divided by ds_scale.
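    # Note: in the softmax backward, dS_ij = P_ij * (dP_ij - Delta_i), where
    # Delta_i = sum_j P_ij * dP_ij = rowsum(dO * O). D holds this per-row Delta,
    # precomputed once per query row so it need not be recomputed for every block.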
    Di = tl.load(D + offs_m)
    # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
    curr_n = start_n
    step_n = BLOCK_N2
    for blk_idx in range(num_steps):
        kT = tl.load(kT_ptrs)
        vT = tl.load(vT_ptrs)
        qk = tl.dot(q, kT)
        p = tl.math.exp2(qk - m)
        # Autoregressive masking.
        if MASK:
            offs_n = curr_n + tl.arange(0, BLOCK_N2)
            mask = (offs_m[:, None] >= offs_n[None, :])
            p = tl.where(mask, p, 0.0)
        # Compute dP and dS.
        dp = tl.dot(do, vT).to(tl.float32)
        ds = p * (dp - Di[:, None])
        ds = ds.to(tl.float16)
        # Compute dQ.
        # NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
        dq += tl.dot(ds, tl.trans(kT))
        # Increment pointers.
        curr_n += step_n
        kT_ptrs += step_n * stride_tok
        vT_ptrs += step_n * stride_tok
    return dq


@triton.jit
def _attn_bwd(Q, K, V, sm_scale,  #
              DO,  #
              DQ, DK, DV,  #
              M, D,
              # shared by Q/K/V/DO.
              stride_z, stride_h, stride_tok, stride_d,  #
              H, N_CTX,  #
              BLOCK_M1: tl.constexpr,  #
              BLOCK_N1: tl.constexpr,  #
              BLOCK_M2: tl.constexpr,  #
              BLOCK_N2: tl.constexpr,  #
              BLK_SLICE_FACTOR: tl.constexpr,  #
              HEAD_DIM: tl.constexpr):
    LN2: tl.constexpr = 0.6931471824645996  # = ln(2)

    bhid = tl.program_id(2)
    off_chz = (bhid * N_CTX).to(tl.int64)
    adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64)
    pid = tl.program_id(0)

    # offset pointers for batch/head
    Q += adj
    K += adj
    V += adj
    DO += adj
    DQ += adj
    DK += adj
    DV += adj
    M += off_chz
    D += off_chz

    # load scales
    offs_k = tl.arange(0, HEAD_DIM)

    start_n = pid * BLOCK_N1
    start_m = start_n

    MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
    offs_n = start_n + tl.arange(0, BLOCK_N1)

    dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
    dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)

    # load K and V: they stay in SRAM throughout the inner loop.
    k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
    v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)

    num_steps = BLOCK_N1 // MASK_BLOCK_M1

    dk, dv = _attn_bwd_dkdv(dk, dv,  #
                            Q, k, v, sm_scale,  #
                            DO,  #
                            M, D,  #
                            stride_tok, stride_d,  #
                            H, N_CTX,  #
                            MASK_BLOCK_M1, BLOCK_N1, HEAD_DIM,  #
                            start_n, start_m, num_steps,  #
                            MASK=True  #
                            )

    start_m += num_steps * MASK_BLOCK_M1
    num_steps = (N_CTX - start_m) // BLOCK_M1

    # Compute dK and dV for non-masked blocks.
    dk, dv = _attn_bwd_dkdv(  #
        dk, dv,  #
        Q, k, v, sm_scale,  #
        DO,  #
        M, D,  #
        stride_tok, stride_d,  #
        H, N_CTX,  #
        BLOCK_M1, BLOCK_N1, HEAD_DIM,  #
        start_n, start_m, num_steps,  #
        MASK=False  #
    )

    dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
    tl.store(dv_ptrs, dv)

    # Write back dK.
    dk *= sm_scale
    dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
    tl.store(dk_ptrs, dk)

    # THIS BLOCK DOES DQ:
    start_m = pid * BLOCK_M2
    end_n = start_m + BLOCK_M2

    MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
    offs_m = start_m + tl.arange(0, BLOCK_M2)

    q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
    dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32)
    do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)

    m = tl.load(M + offs_m)
    m = m[:, None]
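    # Note: assuming the forward pass and wrapper used by this tutorial (not shown
    # in this extract), M stores m_i + log2(l_i) per query row and K is pre-multiplied
    # by sm_scale / ln(2), so exp2(qk - m) in _attn_bwd_dq yields the softmax
    # probabilities directly; the final `dq *= LN2` undoes that pre-scaling.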
    # Compute dQ for masked (diagonal) blocks.
    # NOTE: This code scans each row of QK^T backward (from right to left,
    # but inside each call to _attn_bwd_dq, from left to right), but that's
    # not due to anything important. I just wanted to reuse the loop
    # structure for dK & dV above as much as possible.
    num_steps = BLOCK_M2 // MASK_BLOCK_N2
    dq = _attn_bwd_dq(dq, q, K, V,  #
                      do, m, D,  #
                      stride_tok, stride_d,  #
                      H, N_CTX,  #
                      BLOCK_M2, MASK_BLOCK_N2, HEAD_DIM,  #
                      start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps,  #
                      MASK=True  #
                      )
    end_n -= num_steps * MASK_BLOCK_N2
    # stage 2
    num_steps = end_n // BLOCK_N2
    dq = _attn_bwd_dq(dq, q, K, V,  #
                      do, m, D,  #
                      stride_tok, stride_d,  #
                      H, N_CTX,  #
                      BLOCK_M2, BLOCK_N2, HEAD_DIM,  #
                      start_m, end_n - num_steps * BLOCK_N2, num_steps,  #
                      MASK=False  #
                      )
    # Write back dQ.
    dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
    dq *= LN2
    tl.store(dq_ptrs, dq)


class _attention(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k, v, causal, sm_scale):
        # shape constraints
        HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1]
        # when v is in float8_e5m2 it is transposed.
        HEAD_DIM_V = v.shape[-1]
        assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
        assert HEAD_DIM_K in {16, 32, 64, 128, 256}
        o = torch.empty_like(q)
        stage = 3 if causal else 1
        extra_kern_args = {}
        # Tuning for AMD target
        if is_hip():
            waves_per_eu = 3 if HEAD_DIM_K <= 64 else 2