智源社区 2024年11月14日
【Triton 教程】层标准化
index_new5.html
../../../zaker_core/zaker_tpl_static/wap/tpl_guoji1.html

 

Triton 是一种用于并行编程的语言和编译器,旨在提供基于 Python 的编程环境,高效编写自定义 DNN 计算内核,并在现代 GPU 上实现最大吞吐量。本文以 Layer Normalization 为例,展示如何使用 Triton 编写比 PyTorch 实现更快的自定义内核。Layer Normalization 是一种用于提高序列模型或小批量神经网络性能的算子,它通过减去均值并除以标准差来标准化输入向量,然后应用可学习的线性变换。本文详细介绍了 Triton 内核的实现,包括前向传播和反向传播,并与 PyTorch 实现进行了性能比较,结果表明 Triton 内核在特定场景下可以显著提升性能。

🤔Triton是一种用于并行编程的语言和编译器,旨在优化DNN计算内核在GPU上的执行效率,提供基于Python的编程环境。

🚀Layer Normalization算子通过减去均值并除以标准差来标准化输入向量,再应用线性变换,提升序列模型或小批量神经网络性能。

💡Triton内核实现包括前向传播和反向传播两个阶段,前向传播使用融合内核加速计算,反向传播则采用并行归约策略优化梯度计算。

📊性能比较表明,Triton内核在特定场景下(例如每个特征少于64KB的输入)可以显著优于PyTorch实现,体现出Triton在优化DNN计算方面的优势。

📚Layer Normalization最早由Ba等人在2016年提出,旨在提升神经网络的性能。

Triton 是一种用于并行编程的语言和编译器。它旨在提供一个基于 Python 的编程环境,以高效编写自定义 DNN 计算内核,并能够在现代 GPU 硬件上以最大吞吐量运行。

在本教程中,你将编写一个比 PyTorch 实现运行更快的高性能层标准化 (layer normalization) 内核。

层标准化 (LayerNorm) 算子最先在 BA2016 中提出,旨在提高序列模型(例如Transformers)或小 batchsize 神经网络的性能。它以向量 x 作为输入,并生成与输入 shape 相同的向量 y 作为输出。 标准化是通过减去均值并除以 x 的标准差来实现的。 标准化后,会应用带有权重 w 和偏置 b 的可学习线性变换。

import torchimport tritonimport triton.language as tltry:    # This is https://github.com/NVIDIA/apex, NOT the apex on PyPi, so it    # should not be added to extras_require in setup.py.    # 这是 https://github.com/NVIDIA/apex,不是 PyPi 的 apex,    # 所以不应该加进 setup.py 的额外依赖中    import apex    HAS_APEX = Trueexcept ModuleNotFoundError:    HAS_APEX = False@triton.jitdef _layer_norm_fwd_fused(    X,  # pointer to the input 输入指针    Y,  # pointer to the output 输出指针    W,  # pointer to the weights 权重指针    B,  # pointer to the biases 偏差指针    Mean,  # pointer to the mean 均值指针    Rstd,  # pointer to the 1/std 1/std 指针    stride,  # how much to increase the pointer when moving by 1 row 指针移动一行应该增加多少    N,  # number of columns in X X 的列数    eps,  # epsilon to avoid division by zero 用于避免除以 0 的 epsilon    BLOCK_SIZE: tl.constexpr,):    # Map the program id to the row of X and Y it should compute.    # 映射程序 id 到对应计算的 X 和 Y 的行    row = tl.program_id(0)    Y += row * stride    X += row * stride    # Compute mean    # 计算均值    mean = 0    _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)    for off in range(0, N, BLOCK_SIZE):        cols = off + tl.arange(0, BLOCK_SIZE)        a = tl.load(X + cols, mask=cols 

由于在同一批次中的所有行使用相同的权重 w 和偏差 b,它们的梯度需要累加。为了高效地执行此步骤,我们使用并行归约策略:每个内核实例将某些行的部分 w 和 b 累积到 GROUP_SIZE_M 个独立缓冲区之一中。这些缓冲区保存在 L2 缓存中,然后通过另一个函数进一步归约以计算实际的w 和 b

设输入行数 M=4 和 GROUP_SIZE_M=2,以下是 w 的并行归约策略图示(为简洁起见,省略 b):

在第一阶段,同色的 X 行共享同一个缓冲区,因此使用 lock 以确保一次只有一个内核实例写入缓冲区。在第二阶段,这些缓冲区会进一步归约以计算最终的 w 和 b。在以下实现中,第一阶段由函数 _layer_norm_bwd_dx_fused 实现,第二阶段由函数 _layer_norm_bwd_dwdb 实现。

@triton.jitdef _layer_norm_bwd_dx_fused(DX,  # pointer to the input gradient 输入梯度指针                             DY,  # pointer to the output gradient 输出梯度指针                             DW,  # pointer to the partial sum of weights gradient 权重和梯度指针                             DB,  # pointer to the partial sum of biases gradient 偏差梯度部分和指针                             X,  # pointer to the input 输入指针                             W,  # pointer to the weights 权重指针                             Mean,  # pointer to the mean 均值指针                             Rstd,  # pointer to the 1/std 1/std 指针                             Lock,  # pointer to the lock 锁指针                             stride,  # how much to increase the pointer when moving by 1 row 指针移动一行应该增加多少                             N,  # number of columns in X X 的列数                             GROUP_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):    # Map the program id to the elements of X, DX, and DY it should compute.    # 映射程序 id 到对应计算的 X, DX, DY    row = tl.program_id(0)    cols = tl.arange(0, BLOCK_SIZE_N)    mask = cols 

现在我们可以比较 Triton 内核与 PyTorch 的性能了。这里以每个特征少于 64KB 的输入为例进行讲解。具体来说,可以设置 mode: 'backward' 来进行后向传播的基准测试。

class LayerNorm(torch.autograd.Function):    @staticmethod    def forward(ctx, x, normalized_shape, weight, bias, eps):        # allocate output        # 分配输出        y = torch.empty_like(x)        # reshape input data into 2D tensor        # 将输入数据的形状改为 2D 张量        x_arg = x.reshape(-1, x.shape[-1])        M, N = x_arg.shape        mean = torch.empty((M, ), dtype=torch.float32, device=x.device)        rstd = torch.empty((M, ), dtype=torch.float32, device=x.device)        # Less than 64KB per feature: enqueue fused kernel        # 少于 64KB 每个特征:入队融合内核        MAX_FUSED_SIZE = 65536 // x.element_size()        BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))        if N > BLOCK_SIZE:            raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")        # heuristics for number of warps        # 对 warp 数量的启发算法        num_warps = min(max(BLOCK_SIZE // 256, 1), 8)        # enqueue kernel        # 入队内核        _layer_norm_fwd_fused[(M, )](  #            x_arg, y, weight, bias, mean, rstd,  #            x_arg.stride(0), N, eps,  #            BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, num_ctas=1)        ctx.save_for_backward(x, weight, bias, mean, rstd)        ctx.BLOCK_SIZE = BLOCK_SIZE        ctx.num_warps = num_warps        ctx.eps = eps        return y    @staticmethod    def backward(ctx, dy):        x, w, b, m, v = ctx.saved_tensors        # heuristics for amount of parallel reduction stream for DW/DB        # 计算对 DW/DB 并行规约流数量的启发算法        N = w.shape[0]        GROUP_SIZE_M = 64        if N 

[BA2016] Jimmy Lei Ba and Jamie Ryan Kiros and Geoffrey E. Hinton, “Layer Normalization”, Arxiv 2016

Fish AI Reader

Fish AI Reader

AI辅助创作,多种专业模板,深度分析,高质量内容生成。从观点提取到深度思考,FishAI为您提供全方位的创作支持。新版本引入自定义参数,让您的创作更加个性化和精准。

FishAI

FishAI

鱼阅,AI 时代的下一个智能信息助手,助你摆脱信息焦虑

联系邮箱 441953276@qq.com

相关标签

Triton Layer Normalization GPU加速 DNN 并行编程
相关文章