智源社区 · January 8
[Triton Tutorial] Grouped GEMM

Triton is a language and compiler for parallel programming. It provides a Python-based environment for writing custom DNN compute kernels efficiently and running them at maximum throughput on modern GPU hardware. This article focuses on Triton's grouped GEMM kernel, which computes a group of GEMMs by launching a fixed number of CTAs; the schedule is static and is carried out on the device. The kernel uses the `triton.autotune` decorator for automatic tuning and exposes several configuration options, such as `BLOCK_SIZE_M`, `BLOCK_SIZE_N`, `BLOCK_SIZE_K`, and `NUM_SM`, to optimize performance. Triton aims to execute complex matrix operations efficiently on the GPU, providing strong computational support for fields such as deep learning.

🚀 Triton is a Python-based parallel programming language and compiler focused on writing custom DNN compute kernels efficiently, so as to fully exploit the performance of modern GPU hardware.

🧮 Triton's grouped GEMM kernel computes a group of GEMMs in parallel by launching a fixed number of CTAs. The schedule is static and is carried out directly on the device, which reduces CPU overhead (see the scheduling sketch after this list).

⚙️ The `triton.autotune` decorator automatically tunes kernel configurations such as `BLOCK_SIZE_M`, `BLOCK_SIZE_N`, `BLOCK_SIZE_K`, and `NUM_SM`, to reach the best performance on different hardware.

📊 The kernel code loops over `group_size`, loads each GEMM's dimensions from `group_gemm_sizes`, and computes the corresponding number of tiles, so matrices of different sizes can be handled flexibly.
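
To make the static, on-device schedule concrete, the pure-Python sketch below mimics how a fixed pool of programs divides the work: every output tile of every GEMM gets a global index, and program `pid` handles the tiles whose index satisfies `index % NUM_SM == pid`. The block sizes and problem shapes here are made-up illustration values, not numbers taken from the article.

import math

# Illustrative values only (not from the article).
BLOCK_SIZE_M, BLOCK_SIZE_N = 128, 128
NUM_SM = 4  # number of persistent programs launched

# (M, N, K) of each GEMM in the group.
group_gemm_sizes = [(256, 256, 64), (384, 128, 64), (128, 512, 64)]

# Assign every output tile of every GEMM to one of NUM_SM programs,
# mirroring the tile_idx / NUM_SM striding the kernel performs on-device.
schedule = {pid: [] for pid in range(NUM_SM)}
last_problem_end = 0
for g, (M, N, K) in enumerate(group_gemm_sizes):
    num_m_tiles = math.ceil(M / BLOCK_SIZE_M)
    num_n_tiles = math.ceil(N / BLOCK_SIZE_N)
    num_tiles = num_m_tiles * num_n_tiles
    for t in range(num_tiles):
        tile_idx = last_problem_end + t
        pid = tile_idx % NUM_SM            # owning program
        tile_m, tile_n = t // num_n_tiles, t % num_n_tiles
        schedule[pid].append((g, tile_m, tile_n))
    last_problem_end += num_tiles

for pid, tiles in schedule.items():
    print(f"program {pid} -> tiles {tiles}")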

Triton is a language and compiler for parallel programming. It aims to provide a Python-based programming environment for efficiently writing custom DNN compute kernels that can run at maximum throughput on modern GPU hardware.

The grouped GEMM kernel computes a group of GEMMs by launching a fixed number of CTAs. The scheduling is static and is done on the device.

# Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files
# (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge,
# publish, distribute, sublicense, and/or sell copies of the Software,
# and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

import torch

import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config({
            'BLOCK_SIZE_M': 128,
            'BLOCK_SIZE_N': 128,
            'BLOCK_SIZE_K': 32,
            'NUM_SM': 84,
        }),
        triton.Config({
            'BLOCK_SIZE_M': 128,
            'BLOCK_SIZE_N': 128,
            'BLOCK_SIZE_K': 32,
            'NUM_SM': 128,
        }),
        triton.Config({
            'BLOCK_SIZE_M': 64,
            'BLOCK_SIZE_N': 64,
            'BLOCK_SIZE_K': 32,
            'NUM_SM': 84,
        }),
        triton.Config({
            'BLOCK_SIZE_M': 64,
            'BLOCK_SIZE_N': 64,
            'BLOCK_SIZE_K': 32,
            'NUM_SM': 128,
        }),
    ],
    key=['group_size'],
)
@triton.jit
def grouped_matmul_kernel(
    # device tensor of matrix pointers
    group_a_ptrs,
    group_b_ptrs,
    group_c_ptrs,
    # device tensor of gemm sizes. its shape is [group_size, 3]
    # dim 0 is group_size, dim 1 is the values of <M, N, K> of each gemm
    group_gemm_sizes,
    # device tensor of leading dimension sizes. its shape is [group_size, 3]
    # dim 0 is group_size, dim 1 is the values of <lda, ldb, ldc> of each gemm
    g_lds,
    # number of gemms
    group_size,
    # number of virtual SM
    NUM_SM: tl.constexpr,
    # tile sizes
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    tile_idx = tl.program_id(0)
    last_problem_end = 0
    for g in range(group_size):
        # get the gemm size of the current problem
        gm = tl.load(group_gemm_sizes + g * 3)
        gn = tl.load(group_gemm_sizes + g * 3 + 1)
        gk = tl.load(group_gemm_sizes + g * 3 + 2)
        num_m_tiles = tl.cdiv(gm, BLOCK_SIZE_M)
        num_n_tiles = tl.cdiv(gn, BLOCK_SIZE_N)
        num_tiles = num_m_tiles * num_n_tiles
        # iterate through the tiles in the current gemm problem
        while (tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles):
            # (The listing in the source article is cut off at this point; the
            # remainder of the loop is restored from the upstream Triton tutorial.)
            # pick up a tile from the current gemm problem
            k = gk
            lda = tl.load(g_lds + g * 3)
            ldb = tl.load(g_lds + g * 3 + 1)
            ldc = tl.load(g_lds + g * 3 + 2)
            a_ptr = tl.load(group_a_ptrs + g).to(tl.pointer_type(tl.float16))
            b_ptr = tl.load(group_b_ptrs + g).to(tl.pointer_type(tl.float16))
            c_ptr = tl.load(group_c_ptrs + g).to(tl.pointer_type(tl.float16))
            # figure out the tile coordinates inside the current gemm
            tile_idx_in_gemm = tile_idx - last_problem_end
            tile_m_idx = tile_idx_in_gemm // num_n_tiles
            tile_n_idx = tile_idx_in_gemm % num_n_tiles

            # do a regular blocked gemm for this tile
            offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
            offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
            offs_k = tl.arange(0, BLOCK_SIZE_K)
            a_ptrs = a_ptr + offs_am[:, None] * lda + offs_k[None, :]
            b_ptrs = b_ptr + offs_k[:, None] * ldb + offs_bn[None, :]
            accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
            for kk in range(0, tl.cdiv(k, BLOCK_SIZE_K)):
                # hint to the Triton compiler to do proper loop pipelining
                tl.multiple_of(a_ptrs, [16, 16])
                tl.multiple_of(b_ptrs, [16, 16])
                # assume full tiles for now
                a = tl.load(a_ptrs)
                b = tl.load(b_ptrs)
                accumulator += tl.dot(a, b)
                a_ptrs += BLOCK_SIZE_K
                b_ptrs += BLOCK_SIZE_K * ldb
            c = accumulator.to(tl.float16)

            offs_cm = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
            offs_cn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
            c_ptrs = c_ptr + ldc * offs_cm[:, None] + offs_cn[None, :]

            # assume a full tile for now
            tl.store(c_ptrs, c)

            # go to the next tile by advancing NUM_SM
            tile_idx += NUM_SM

        # get ready to go to the next gemm problem
        last_problem_end = last_problem_end + num_tiles
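
To actually run the kernel, a host-side wrapper has to pack each GEMM's pointers, sizes, and leading dimensions into device tensors and then launch a fixed grid of `NUM_SM` programs. The sketch below is a minimal driver loosely modeled on the tutorial's host code; the function name `group_gemm_fn` and the float16 example shapes are illustrative assumptions, not text from this article.

def group_gemm_fn(group_A, group_B):
    # Pack per-GEMM metadata (pointers, sizes, leading dimensions) into
    # flat host lists, then copy them to the device for the kernel to read.
    assert len(group_A) == len(group_B)
    device = group_A[0].device
    group_size = len(group_A)

    A_addrs, B_addrs, C_addrs = [], [], []
    g_sizes, g_lds = [], []
    group_C = []
    for A, B in zip(group_A, group_B):
        assert A.shape[1] == B.shape[0]
        M, K = A.shape
        K, N = B.shape
        C = torch.empty((M, N), device=device, dtype=A.dtype)
        group_C.append(C)
        A_addrs.append(A.data_ptr())
        B_addrs.append(B.data_ptr())
        C_addrs.append(C.data_ptr())
        g_sizes += [M, N, K]                              # <M, N, K>
        g_lds += [A.stride(0), B.stride(0), C.stride(0)]  # <lda, ldb, ldc>

    d_a_ptrs = torch.tensor(A_addrs, device=device)
    d_b_ptrs = torch.tensor(B_addrs, device=device)
    d_c_ptrs = torch.tensor(C_addrs, device=device)
    d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=device)
    d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=device)

    # Launch a fixed number of programs (NUM_SM comes from the autotuner);
    # the kernel itself walks the tiles of every GEMM in the group.
    grid = lambda META: (META['NUM_SM'], )
    grouped_matmul_kernel[grid](
        d_a_ptrs, d_b_ptrs, d_c_ptrs,
        d_g_sizes, d_g_lds,
        group_size,
    )
    return group_C


# Example usage: three GEMMs of different shapes, float16 on a CUDA GPU.
group_A = [torch.randn(m, 64, device='cuda', dtype=torch.float16) for m in (128, 256, 384)]
group_B = [torch.randn(64, n, device='cuda', dtype=torch.float16) for n in (256, 128, 512)]
group_C = group_gemm_fn(group_A, group_B)

The key point is that only one launch happens per group: which program computes which tile is decided by the kernel's own loop on the device, not by the CPU.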


Related tags

Triton · Parallel Programming · GPU · GEMM · Deep Learning