Triton 是一种用于并行编程的语言和编译器。它旨在提供一个基于 Python 的编程环境,以高效编写自定义 DNN 计算内核,并能够在现代 GPU 硬件上以最大吞吐量运行。
在本教程中,你将编写一个比 PyTorch 实现运行更快的高性能层标准化 (layer normalization) 内核。
层标准化 (LayerNorm) 算子最先在 BA2016 中提出,旨在提高序列模型(例如Transformers)或小 batchsize 神经网络的性能。它以向量 x 作为输入,并生成与输入 shape 相同的向量 y 作为输出。 标准化是通过减去均值并除以 x 的标准差来实现的。 标准化后,会应用带有权重 w 和偏置 b 的可学习线性变换。
import torchimport tritonimport triton.language as tltry: # This is https://github.com/NVIDIA/apex, NOT the apex on PyPi, so it # should not be added to extras_require in setup.py. # 这是 https://github.com/NVIDIA/apex,不是 PyPi 的 apex, # 所以不应该加进 setup.py 的额外依赖中 import apex HAS_APEX = Trueexcept ModuleNotFoundError: HAS_APEX = False@triton.jitdef _layer_norm_fwd_fused( X, # pointer to the input 输入指针 Y, # pointer to the output 输出指针 W, # pointer to the weights 权重指针 B, # pointer to the biases 偏差指针 Mean, # pointer to the mean 均值指针 Rstd, # pointer to the 1/std 1/std 指针 stride, # how much to increase the pointer when moving by 1 row 指针移动一行应该增加多少 N, # number of columns in X X 的列数 eps, # epsilon to avoid division by zero 用于避免除以 0 的 epsilon BLOCK_SIZE: tl.constexpr,): # Map the program id to the row of X and Y it should compute. # 映射程序 id 到对应计算的 X 和 Y 的行 row = tl.program_id(0) Y += row * stride X += row * stride # Compute mean # 计算均值 mean = 0 _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32) for off in range(0, N, BLOCK_SIZE): cols = off + tl.arange(0, BLOCK_SIZE) a = tl.load(X + cols, mask=cols
由于在同一批次中的所有行使用相同的权重 w 和偏差 b,它们的梯度需要累加。为了高效地执行此步骤,我们使用并行归约策略:每个内核实例将某些行的部分 ∇w 和 ∇b 累积到 GROUP_SIZE_M 个独立缓冲区之一中。这些缓冲区保存在 L2 缓存中,然后通过另一个函数进一步归约以计算实际的∇w 和 ∇b。
设输入行数 M=4 和 GROUP_SIZE_M=2,以下是 ∇w 的并行归约策略图示(为简洁起见,省略 ∇b):
在第一阶段,同色的 X 行共享同一个缓冲区,因此使用 lock 以确保一次只有一个内核实例写入缓冲区。在第二阶段,这些缓冲区会进一步归约以计算最终的 ∇w 和 ∇b。在以下实现中,第一阶段由函数 _layer_norm_bwd_dx_fused 实现,第二阶段由函数 _layer_norm_bwd_dwdb 实现。
@triton.jitdef _layer_norm_bwd_dx_fused(DX, # pointer to the input gradient 输入梯度指针 DY, # pointer to the output gradient 输出梯度指针 DW, # pointer to the partial sum of weights gradient 权重和梯度指针 DB, # pointer to the partial sum of biases gradient 偏差梯度部分和指针 X, # pointer to the input 输入指针 W, # pointer to the weights 权重指针 Mean, # pointer to the mean 均值指针 Rstd, # pointer to the 1/std 1/std 指针 Lock, # pointer to the lock 锁指针 stride, # how much to increase the pointer when moving by 1 row 指针移动一行应该增加多少 N, # number of columns in X X 的列数 GROUP_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr): # Map the program id to the elements of X, DX, and DY it should compute. # 映射程序 id 到对应计算的 X, DX, DY row = tl.program_id(0) cols = tl.arange(0, BLOCK_SIZE_N) mask = cols
现在我们可以比较 Triton 内核与 PyTorch 的性能了。这里以每个特征少于 64KB 的输入为例进行讲解。具体来说,可以设置 mode: 'backward' 来进行后向传播的基准测试。
class LayerNorm(torch.autograd.Function): @staticmethod def forward(ctx, x, normalized_shape, weight, bias, eps): # allocate output # 分配输出 y = torch.empty_like(x) # reshape input data into 2D tensor # 将输入数据的形状改为 2D 张量 x_arg = x.reshape(-1, x.shape[-1]) M, N = x_arg.shape mean = torch.empty((M, ), dtype=torch.float32, device=x.device) rstd = torch.empty((M, ), dtype=torch.float32, device=x.device) # Less than 64KB per feature: enqueue fused kernel # 少于 64KB 每个特征:入队融合内核 MAX_FUSED_SIZE = 65536 // x.element_size() BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) if N > BLOCK_SIZE: raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") # heuristics for number of warps # 对 warp 数量的启发算法 num_warps = min(max(BLOCK_SIZE // 256, 1), 8) # enqueue kernel # 入队内核 _layer_norm_fwd_fused[(M, )]( # x_arg, y, weight, bias, mean, rstd, # x_arg.stride(0), N, eps, # BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, num_ctas=1) ctx.save_for_backward(x, weight, bias, mean, rstd) ctx.BLOCK_SIZE = BLOCK_SIZE ctx.num_warps = num_warps ctx.eps = eps return y @staticmethod def backward(ctx, dy): x, w, b, m, v = ctx.saved_tensors # heuristics for amount of parallel reduction stream for DW/DB # 计算对 DW/DB 并行规约流数量的启发算法 N = w.shape[0] GROUP_SIZE_M = 64 if N
[BA2016] Jimmy Lei Ba and Jamie Ryan Kiros and Geoffrey E. Hinton, “Layer Normalization”, Arxiv 2016