Triton is a language and compiler for parallel programming. It aims to provide a Python-based programming environment for productively writing custom DNN compute kernels capable of running at maximal throughput on modern GPU hardware.
The grouped GEMM kernel computes a group of GEMMs by launching a fixed number of CTAs. The scheduling is static and is performed on the device.
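To make "static, on-device scheduling" concrete before reading the kernel, the pure-Python sketch below (an illustration only; the helper `static_schedule` is hypothetical and not part of the tutorial code) replays the schedule on the host: every problem contributes its M-tiles × N-tiles to one global tile space, and program `p` handles tiles `p, p + NUM_SM, p + 2 * NUM_SM, ...`. The actual kernel that follows implements exactly this walk, with `tl.program_id(0)` playing the role of `p`.

```python
# Illustration only: replay the kernel's static schedule on the host.
def static_schedule(problem_sizes, BLOCK_M, BLOCK_N, NUM_SM):
    assignments = {p: [] for p in range(NUM_SM)}
    last_problem_end = 0
    for g, (M, N, _K) in enumerate(problem_sizes):
        num_tiles = -(-M // BLOCK_M) * -(-N // BLOCK_N)  # ceiling division
        for tile in range(last_problem_end, last_problem_end + num_tiles):
            # record (problem index, tile index within that problem)
            assignments[tile % NUM_SM].append((g, tile - last_problem_end))
        last_problem_end += num_tiles
    return assignments

# e.g. two problems with 4 and 2 output tiles, shared among 3 programs:
# static_schedule([(256, 256, 64), (128, 256, 64)], 128, 128, 3)
# -> {0: [(0, 0), (0, 3)], 1: [(0, 1), (1, 0)], 2: [(0, 2), (1, 1)]}
```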
```python
# Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files
# (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge,
# publish, distribute, sublicense, and/or sell copies of the Software,
# and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

import torch

import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config({
            'BLOCK_SIZE_M': 128,
            'BLOCK_SIZE_N': 128,
            'BLOCK_SIZE_K': 32,
            'NUM_SM': 84,
        }),
        triton.Config({
            'BLOCK_SIZE_M': 128,
            'BLOCK_SIZE_N': 128,
            'BLOCK_SIZE_K': 32,
            'NUM_SM': 128,
        }),
        triton.Config({
            'BLOCK_SIZE_M': 64,
            'BLOCK_SIZE_N': 64,
            'BLOCK_SIZE_K': 32,
            'NUM_SM': 84,
        }),
        triton.Config({
            'BLOCK_SIZE_M': 64,
            'BLOCK_SIZE_N': 64,
            'BLOCK_SIZE_K': 32,
            'NUM_SM': 128,
        }),
    ],
    key=['group_size'],
)
@triton.jit
def grouped_matmul_kernel(
    # device tensor of matrices pointers
    group_a_ptrs,
    group_b_ptrs,
    group_c_ptrs,
    # device tensor of gemm sizes. its shape is [group_size, 3]
    # dim 0 is group_size, dim 1 is the values of <M, N, K> of each gemm
    group_gemm_sizes,
    # device tensor of leading dimension sizes. its shape is [group_size, 3]
    # dim 0 is group_size, dim 1 is the values of <lda, ldb, ldc> of each gemm
    g_lds,
    # number of gemms
    group_size,
    # number of virtual SM
    NUM_SM: tl.constexpr,
    # tile sizes
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    tile_idx = tl.program_id(0)
    last_problem_end = 0
    for g in range(group_size):
        # get the gemm size of the current problem
        gm = tl.load(group_gemm_sizes + g * 3)
        gn = tl.load(group_gemm_sizes + g * 3 + 1)
        gk = tl.load(group_gemm_sizes + g * 3 + 2)
        num_m_tiles = tl.cdiv(gm, BLOCK_SIZE_M)
        num_n_tiles = tl.cdiv(gn, BLOCK_SIZE_N)
        num_tiles = num_m_tiles * num_n_tiles
        # iterate through the tiles in the current gemm problem
        while (tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles):
            # pick up a tile from the current gemm problem
            k = gk
            lda = tl.load(g_lds + g * 3)
            ldb = tl.load(g_lds + g * 3 + 1)
            ldc = tl.load(g_lds + g * 3 + 2)
            a_ptr = tl.load(group_a_ptrs + g).to(tl.pointer_type(tl.float16))
            b_ptr = tl.load(group_b_ptrs + g).to(tl.pointer_type(tl.float16))
            c_ptr = tl.load(group_c_ptrs + g).to(tl.pointer_type(tl.float16))
            # figure out tile coordinates
            tile_idx_in_gemm = tile_idx - last_problem_end
            tile_m_idx = tile_idx_in_gemm // num_n_tiles
            tile_n_idx = tile_idx_in_gemm % num_n_tiles

            # do regular gemm here
            offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
            offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
            offs_k = tl.arange(0, BLOCK_SIZE_K)
            a_ptrs = a_ptr + offs_am[:, None] * lda + offs_k[None, :]
            b_ptrs = b_ptr + offs_k[:, None] * ldb + offs_bn[None, :]
            accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
            for kk in range(0, tl.cdiv(k, BLOCK_SIZE_K)):
                # hint to Triton compiler to do proper loop pipelining
                tl.multiple_of(a_ptrs, [16, 16])
                tl.multiple_of(b_ptrs, [16, 16])
                # assume full tile for now
                a = tl.load(a_ptrs)
                b = tl.load(b_ptrs)
                accumulator += tl.dot(a, b)
                a_ptrs += BLOCK_SIZE_K
                b_ptrs += BLOCK_SIZE_K * ldb
            c = accumulator.to(tl.float16)

            offs_cm = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
            offs_cn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
            c_ptrs = c_ptr + ldc * offs_cm[:, None] + offs_cn[None, :]

            # assumes full tile for now
            tl.store(c_ptrs, c)

            # go to the next tile by advancing NUM_SM
            tile_idx += NUM_SM

        # get ready to go to the next gemm problem
        last_problem_end = last_problem_end + num_tiles
```
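The kernel consumes device tensors of raw pointers, sizes, and leading dimensions, so the host has to flatten a Python list of matrices into that layout before launching. The sketch below shows one minimal way to do that, assuming float16, row-major CUDA tensors; the helper name `group_gemm_fn` is illustrative and not defined in the code above.

```python
def group_gemm_fn(group_A, group_B):
    # assumes all matrices are float16, row-major, and already on the GPU
    assert len(group_A) == len(group_B)
    device = torch.device('cuda')
    group_size = len(group_A)

    A_addrs, B_addrs, C_addrs = [], [], []
    g_sizes, g_lds = [], []
    group_C = []
    for A, B in zip(group_A, group_B):
        assert A.shape[1] == B.shape[0]
        M, K = A.shape
        K, N = B.shape
        C = torch.empty((M, N), device=device, dtype=A.dtype)
        group_C.append(C)
        # raw device addresses, reinterpreted as float16 pointers in the kernel
        A_addrs.append(A.data_ptr())
        B_addrs.append(B.data_ptr())
        C_addrs.append(C.data_ptr())
        g_sizes += [M, N, K]
        g_lds += [A.stride(0), B.stride(0), C.stride(0)]

    # note: these are device tensors, matching the kernel's expectations
    d_a_ptrs = torch.tensor(A_addrs, device=device)
    d_b_ptrs = torch.tensor(B_addrs, device=device)
    d_c_ptrs = torch.tensor(C_addrs, device=device)
    d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=device)
    d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=device)

    # launch a fixed number of CTAs; NUM_SM is chosen by the autotuner
    grid = lambda META: (META['NUM_SM'], )
    grouped_matmul_kernel[grid](
        d_a_ptrs, d_b_ptrs, d_c_ptrs,
        d_g_sizes, d_g_lds,
        group_size,
    )
    return group_C
```

Because the grid is fixed at `NUM_SM` programs regardless of how many GEMMs are in the group, no host-side work is needed when problem sizes change between calls; only the size and pointer tensors are rebuilt.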