智源社区 01月15日
【Triton 教程】持久矩阵乘法 (Persistent Matmul)
index_new5.html
../../../zaker_core/zaker_tpl_static/wap/tpl_guoji1.html

 

Triton 是一种专为并行编程设计的语言和编译器,它基于 Python 环境,旨在高效编写自定义深度神经网络(DNN)计算内核,并在现代 GPU 硬件上实现最大吞吐量。文章展示了使用 Triton 实现矩阵乘法的持久化内核,包括朴素方法、持久化方法和基于张量内存加速器(TMA)的方法。这些内核支持半精度浮点数(FP16)和8位浮点数(FP8),但FP8仅在计算能力大于等于9.0的CUDA设备上可用。Triton 与 cuBLAS 的性能通过基准测试和质子分析器进行评估。用户可通过命令行参数灵活调整矩阵维度和迭代步骤。请注意,在共享内存较小的设备上,此教程可能会失败。

🚀Triton 是一种基于 Python 的并行编程语言和编译器,专注于为现代 GPU 硬件高效编写自定义 DNN 计算内核。

🧮该脚本展示了使用 Triton 实现矩阵乘法的多种方法,包括朴素方法、持久化方法和基于张量内存加速器(TMA)的方法,并支持FP16和FP8数据类型。

📊Triton 的性能通过基准测试与 cuBLAS 进行比较,并使用质子分析器进行评估,用户可以通过命令行参数灵活调整矩阵维度和迭代步骤。

Triton 是一种用于并行编程的语言和编译器。它旨在提供一个基于 Python 的编程环境,以高效编写自定义 DNN 计算内核,并能够在现代 GPU 硬件上以最大吞吐量运行。

该脚本展示了使用 Triton 进行矩阵乘法的持久化内核实现 (persistent kernel implementations)。包含多种矩阵乘法方法,例如基础的朴素方法 (naive)、持久化方法 (persistent) 以及基于张量内存加速器(TMA,Tensor Memory Accelerator)的方法。这些内核同时支持半精度浮点数(FP16)和 8 位浮点数(FP8)数据类型,但 FP8 的实现仅在计算能力大于等于 9.0 的 CUDA 设备上可用。

Triton 与 cuBLAS 的具体实现将会在多种各异的配置情形下开展基准测试工作,并通过质子分析器 (proton profiler) 进行评估。使用者可以通过命令行参数灵活指定矩阵的维度和迭代步骤。

请注意,当前这个教程可能会在共享内存较小的设备上失败,比如 RTX-4090。

import argparseimport torchimport tritonimport triton.language as tlimport triton.tools.experimental_descriptorimport triton.profiler as protonfrom contextlib import contextmanagerif torch.cuda.is_available():    from triton._C.libtriton import nvidia    cublas_workspace = torch.empty(32 * 1024 * 1024, device="cuda", dtype=torch.uint8)    cublas = nvidia.cublas.CublasLt(cublas_workspace)else:    cublas = Nonedef is_cuda():    return triton.runtime.driver.active.get_current_target().backend == "cuda"def supports_tma():    return is_cuda() and torch.cuda.get_device_capability()[0] >= 9def _matmul_launch_metadata(grid, kernel, args):    ret = {}    M, N, K = args["M"], args["N"], args["K"]    ret["name"] = f"{kernel.name} [M={M}, N={N}, K={K}]"    if "tiles_per_update" in args:        ret["name"] = f"{kernel.name} [M={M}, N={N}, K={K}, tiles_per_update={args['tiles_per_update']:02}]"    if "c_ptr" in args:        bytes_per_elem = args["c_ptr"].element_size()    else:        bytes_per_elem = 1 if args["FP8_OUTPUT"] else 2    ret[f"flops{bytes_per_elem * 8}"] = 2. * M * N * K    ret["bytes"] = bytes_per_elem * (M * K + N * K + M * N)    return ret@triton.jit(launch_metadata=_matmul_launch_metadata)def matmul_kernel(a_ptr, b_ptr, c_ptr,  #                  M, N, K,  #                  stride_am, stride_ak,  #                  stride_bk, stride_bn,  #                  stride_cm, stride_cn,  #                  BLOCK_SIZE_M: tl.constexpr,  #                  BLOCK_SIZE_N: tl.constexpr,  #                  BLOCK_SIZE_K: tl.constexpr,  #                  GROUP_SIZE_M: tl.constexpr,  #                  ):    pid = tl.program_id(axis=0)    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)    num_pid_in_group = GROUP_SIZE_M * num_pid_n    group_id = pid // num_pid_in_group    first_pid_m = group_id * GROUP_SIZE_M    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)    pid_m = first_pid_m + (pid % group_size_m)    pid_n = (pid % num_pid_in_group) // group_size_m    start_m = pid_m * BLOCK_SIZE_M    start_n = pid_n * BLOCK_SIZE_N    offs_am = start_m + tl.arange(0, BLOCK_SIZE_M)    offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N)    offs_am = tl.where(offs_am 

内容中包含的图片若涉及版权问题,请及时与我们联系删除

Fish AI Reader

Fish AI Reader

AI辅助创作,多种专业模板,深度分析,高质量内容生成。从观点提取到深度思考,FishAI为您提供全方位的创作支持。新版本引入自定义参数,让您的创作更加个性化和精准。

FishAI

FishAI

鱼阅,AI 时代的下一个智能信息助手,助你摆脱信息焦虑

联系邮箱 441953276@qq.com

相关标签

Triton 并行编程 GPU 矩阵乘法 DNN
相关文章