Triton 是一种用于并行编程的语言和编译器。它旨在提供一个基于 Python 的编程环境,以高效编写自定义 DNN 计算内核,并能够在现代 GPU 硬件上以最大吞吐量运行。
该脚本展示了使用 Triton 进行矩阵乘法的持久化内核实现 (persistent kernel implementations)。包含多种矩阵乘法方法,例如基础的朴素方法 (naive)、持久化方法 (persistent) 以及基于张量内存加速器(TMA,Tensor Memory Accelerator)的方法。这些内核同时支持半精度浮点数(FP16)和 8 位浮点数(FP8)数据类型,但 FP8 的实现仅在计算能力大于等于 9.0 的 CUDA 设备上可用。
Triton 与 cuBLAS 的具体实现将会在多种各异的配置情形下开展基准测试工作,并通过质子分析器 (proton profiler) 进行评估。使用者可以通过命令行参数灵活指定矩阵的维度和迭代步骤。
请注意,当前这个教程可能会在共享内存较小的设备上失败,比如 RTX-4090。
import argparseimport torchimport tritonimport triton.language as tlimport triton.tools.experimental_descriptorimport triton.profiler as protonfrom contextlib import contextmanagerif torch.cuda.is_available(): from triton._C.libtriton import nvidia cublas_workspace = torch.empty(32 * 1024 * 1024, device="cuda", dtype=torch.uint8) cublas = nvidia.cublas.CublasLt(cublas_workspace)else: cublas = Nonedef is_cuda(): return triton.runtime.driver.active.get_current_target().backend == "cuda"def supports_tma(): return is_cuda() and torch.cuda.get_device_capability()[0] >= 9def _matmul_launch_metadata(grid, kernel, args): ret = {} M, N, K = args["M"], args["N"], args["K"] ret["name"] = f"{kernel.name} [M={M}, N={N}, K={K}]" if "tiles_per_update" in args: ret["name"] = f"{kernel.name} [M={M}, N={N}, K={K}, tiles_per_update={args['tiles_per_update']:02}]" if "c_ptr" in args: bytes_per_elem = args["c_ptr"].element_size() else: bytes_per_elem = 1 if args["FP8_OUTPUT"] else 2 ret[f"flops{bytes_per_elem * 8}"] = 2. * M * N * K ret["bytes"] = bytes_per_elem * (M * K + N * K + M * N) return ret@triton.jit(launch_metadata=_matmul_launch_metadata)def matmul_kernel(a_ptr, b_ptr, c_ptr, # M, N, K, # stride_am, stride_ak, # stride_bk, stride_bn, # stride_cm, stride_cn, # BLOCK_SIZE_M: tl.constexpr, # BLOCK_SIZE_N: tl.constexpr, # BLOCK_SIZE_K: tl.constexpr, # GROUP_SIZE_M: tl.constexpr, # ): pid = tl.program_id(axis=0) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) num_pid_in_group = GROUP_SIZE_M * num_pid_n group_id = pid // num_pid_in_group first_pid_m = group_id * GROUP_SIZE_M group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) pid_m = first_pid_m + (pid % group_size_m) pid_n = (pid % num_pid_in_group) // group_size_m start_m = pid_m * BLOCK_SIZE_M start_n = pid_n * BLOCK_SIZE_N offs_am = start_m + tl.arange(0, BLOCK_SIZE_M) offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N) offs_am = tl.where(offs_am
内容中包含的图片若涉及版权问题,请及时与我们联系删除