智源社区 · February 10
[TVM Tutorial] Auto-scheduling a Convolution Layer for GPU

This article describes how to use Apache TVM's auto-scheduler to optimize deep learning models for GPUs. Unlike AutoTVM, which relies on hand-written templates, the auto-scheduler needs no templates at all: given only a compute declaration, it automatically generates a search space and finds a good schedule. Using the last convolution layer of ResNet as an example, the article shows how to define the computation, create a search task, set the search parameters, and load and apply the best schedule. Auto-scheduling enables optimizations such as multi-level tiling, cooperative fetching, unrolling, and operator fusion, which can significantly improve compute performance on GPUs. The article also covers how to inspect the scheduled IR, verify the correctness of the results, and evaluate execution time, giving developers a complete GPU optimization workflow.

🚀 The auto-scheduler is Apache TVM's framework for compiling deep learning workloads for GPUs. It removes the dependence on manual templates: from the compute declaration alone it automatically generates a search space and searches for a good schedule.

🎯 Using the last convolution layer of ResNet-50 as an example, the article shows how to define the layer's compute function, create a search task, and set the target device to CUDA to start the auto-scheduler's search.

⏱️ The optimizations applied by the auto-scheduler include multi-level tiling, cooperative data fetching, loop unrolling, and operator fusion. These optimizations significantly reduce the latency of the resulting schedule and improve GPU compute efficiency; the measured execution time is noticeably shorter after optimization.

💾 The auto-scheduler dumps all measurement records into a JSON file, so users can re-apply the search results, resume the search, and perform further analysis and debugging.

🐍 By loading the best schedule and printing the equivalent Python schedule API and the CUDA source code, users can better understand the auto-scheduler's behavior and carry out customized debugging and optimization.

Apache TVM is an end-to-end deep learning compiler framework for CPUs, GPUs, and various machine learning accelerators. More TVM documentation in Chinese is available at → https://tvm.hyper.ai/

Authors: Lianmin Zheng, Chengfan Jia

This article shows how to use the auto-scheduler for GPUs.

Unlike AutoTVM, which relies on manual templates to define the search space, the auto-scheduler does not require any templates. Users only write the compute declaration, without any schedule commands or templates. The auto-scheduler then automatically generates a large search space and finds a good schedule within it.
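
The examples below also rely on a few imports that this excerpt does not show. The following minimal set matches the names used in the code later in this article (te, topi, auto_scheduler, NumPy, and the NumPy reference convolution); adjust it to your TVM installation if needed.

import numpy as np
import tvm
from tvm import te, auto_scheduler, topi
from tvm.topi.testing import conv2d_nchw_python  # NumPy reference used later for the correctness check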

First, define the computation of the convolution layer. The function returns a list of input/output tensors; from these tensors the auto-scheduler can recover the entire computational graph.

@auto_scheduler.register_workload
def conv2d_layer(N, H, W, CO, CI, KH, KW, stride, padding):
    data = te.placeholder((N, CI, H, W), name="data")
    kernel = te.placeholder((CO, CI, KH, KW), name="kernel")
    bias = te.placeholder((1, CO, 1, 1), name="bias")
    conv = topi.nn.conv2d_nchw(data, kernel, stride, padding, dilation=1, out_dtype="float32")
    out = topi.nn.relu(conv + bias)
    return [data, kernel, bias, out]

Next, create a search task for the last convolution layer of ResNet.

target = tvm.target.Target("cuda")

# Use the last layer in ResNet-50
N, H, W, CO, CI, KH, KW, strides, padding = 1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1)
task = auto_scheduler.SearchTask(
    func=conv2d_layer, args=(N, H, W, CO, CI, KH, KW, strides, padding), target=target
)

# Inspect the computational graph
print("Computational DAG:")
print(task.compute_dag)
Computational DAG:
data = PLACEHOLDER [1, 512, 7, 7]
pad_temp(i0, i1, i2, i3) = tir.if_then_else(((((i2 >= 1) && (i2 < 8)) && (i3 >= 1)) && (i3 < 8)), ...

Next, set parameters for the auto-scheduler. These parameters mainly specify how measurements are performed during the search.
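
The excerpt omits the corresponding code, so here is a rough sketch of a typical measurement setup using the standard auto_scheduler API. The record file name conv2d.json matches the one mentioned later in this article; num_measure_trials is an assumed small value for a quick demonstration, while a real search usually needs on the order of 1000 trials.

log_file = "conv2d.json"
# Measure candidates through a local RPC runner; min_repeat_ms makes each timing run long enough to be stable
measure_ctx = auto_scheduler.LocalRPCMeasureContext(min_repeat_ms=300)
tune_option = auto_scheduler.TuningOptions(
    num_measure_trials=10,  # small value for a quick demo; use ~1000 trials for a real search
    runner=measure_ctx.runner,
    measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
    verbose=2,
)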

With all inputs ready, we can start the search and let the auto-scheduler do its work. After some measurement trials, the best schedule can be loaded from the log file and applied.
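
In code, this step looks roughly like the following sketch (tune_option, measure_ctx, and log_file come from the setup above):

# Run the search; the auto-scheduler measures candidate schedules and records them to log_file
task.tune(tune_option)
# Load and apply the best schedule found in the log file
sch, args = task.apply_best(log_file)
# Terminate the measurement process
del measure_ctx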

Lower the schedule after auto-scheduling to inspect the IR. The auto-scheduler performs the optimizations correctly, including multi-level tiling, cooperative fetching, unrolling, and operator fusion.
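
A minimal way to print the lowered IR, assuming sch and args come from task.apply_best above:

print("Lowered TIR:")
print(tvm.lower(sch, args, simple_mode=True))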

Lowered TIR:
@main = primfn(data_1: handle, kernel_1: handle, bias_1: handle, compute_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {data: Buffer(data_2: Pointer(float32), float32, [25088], []),
             kernel: Buffer(kernel_2: Pointer(float32), float32, [2359296], []),
             bias: Buffer(bias_2: Pointer(float32), float32, [512], []),
             compute: Buffer(compute_2: Pointer(float32), float32, [25088], [])}
  buffer_map = {data_1: data, kernel_1: kernel, bias_1: bias, compute_1: compute}
  preflattened_buffer_map = {data_1: data_3: Buffer(data_2, float32, [1, 512, 7, 7], []), kernel_1: kernel_3: Buffer(kernel_2, float32, [512, 512, 3, 3], []), bias_1: bias_3: Buffer(bias_2, float32, [1, 512, 1, 1], []), compute_1: compute_3: Buffer(compute_2, float32, [1, 512, 7, 7], [])} {
  attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = 28;
  allocate(conv2d_nchw: Pointer(local float32), float32, [14]), storage_scope = local;
  allocate(pad_temp.shared: Pointer(shared float32), float32, [72]), storage_scope = shared;
  allocate(kernel.shared: Pointer(shared float32), float32, [3072]), storage_scope = shared;
  attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 64 {
    conv2d_nchw_1: Buffer(conv2d_nchw, float32, [14], [], scope="local", align=32)[0] = 0f32
    conv2d_nchw_1[1] = 0f32
    conv2d_nchw_1[2] = 0f32
    conv2d_nchw_1[3] = 0f32
    conv2d_nchw_1[4] = 0f32
    conv2d_nchw_1[5] = 0f32
    conv2d_nchw_1[6] = 0f32
    conv2d_nchw_1[7] = 0f32
    conv2d_nchw_1[8] = 0f32
    conv2d_nchw_1[9] = 0f32
    conv2d_nchw_1[10] = 0f32
    conv2d_nchw_1[11] = 0f32
    conv2d_nchw_1[12] = 0f32
    conv2d_nchw_1[13] = 0f32
    for (rc.outer.outer: int32, 0, 64) {
      for (ry.outer.outer: int32, 0, 3) {
        let cse_var_2: int32 = (rc.outer.outer*72)
        let cse_var_1: int32 = (ry.outer.outer*3)
        {
          attr [IterVar(threadIdx.x_1: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 64 {
            if @tir.likely((threadIdx.x_1 ...
After lowering, build the function, check its correctness against a NumPy reference implementation, and evaluate its execution time.

func = tvm.build(sch, args, target)

# Check correctness
data_np = np.random.uniform(size=(N, CI, H, W)).astype(np.float32)
weight_np = np.random.uniform(size=(CO, CI, KH, KW)).astype(np.float32)
bias_np = np.random.uniform(size=(1, CO, 1, 1)).astype(np.float32)
conv_np = conv2d_nchw_python(data_np, weight_np, strides, padding)
out_np = np.maximum(conv_np + bias_np, 0.0)

dev = tvm.cuda()
data_tvm = tvm.nd.array(data_np, device=dev)
weight_tvm = tvm.nd.array(weight_np, device=dev)
bias_tvm = tvm.nd.array(bias_np, device=dev)
out_tvm = tvm.nd.empty(out_np.shape, device=dev)
func(data_tvm, weight_tvm, bias_tvm, out_tvm)

# Check results
np.testing.assert_allclose(out_np, out_tvm.numpy(), rtol=1e-3)

# Evaluate execution time
evaluator = func.time_evaluator(func.entry_name, dev, min_repeat_ms=500)
print(
    "Execution time of this operator: %.3f ms"
    % (np.median(evaluator(data_tvm, weight_tvm, bias_tvm, out_tvm).results) * 1000)
)

During the search, all measurement records (which can be used to re-apply the search result, resume the search, and perform other analyses) are dumped into the record file "conv2d.json".
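
The records can also be read back programmatically, for example with auto_scheduler.load_records (a small sketch; what you inspect on each record depends on the analysis you want to do):

# Each record pairs a candidate schedule with its measured result
records = list(auto_scheduler.load_records("conv2d.json"))
print("Number of measurement records:", len(records))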

The example below loads the best schedule from the file and prints the equivalent Python schedule API and the CUDA source code, which can be used to debug and analyze the auto-scheduler's behavior.
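
One way to do this is SearchTask.print_best, where print_mode selects between the schedule-API view and the generated CUDA kernel (a sketch; log_file is the conv2d.json record file from above):

print("Equivalent python schedule:")
print(task.print_best(log_file, print_mode="schedule"))

print("CUDA source code:")
print(task.print_best(log_file, print_mode="cuda"))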

Equivalent python schedule:
pad_temp_i0, pad_temp_i1, pad_temp_i2, pad_temp_i3 = tuple(pad_temp.op.axis) + tuple(pad_temp.op.reduce_axis)
conv2d_nchw_nn, conv2d_nchw_ff, conv2d_nchw_yy, conv2d_nchw_xx, conv2d_nchw_rc, conv2d_nchw_ry, conv2d_nchw_rx = tuple(conv2d_nchw.op.axis) + tuple(conv2d_nchw.op.reduce_axis)
T_add_ax0, T_add_ax1, T_add_ax2, T_add_ax3 = tuple(T_add.op.axis) + tuple(T_add.op.reduce_axis)
compute_i0, compute_i1, compute_i2, compute_i3 = tuple(compute.op.axis) + tuple(compute.op.reduce_axis)
s[T_add].compute_inline()
conv2d_nchw_nn_o_i, conv2d_nchw_nn_i = s[conv2d_nchw].split(conv2d_nchw_nn, factor=1)
conv2d_nchw_nn_o_o_i, conv2d_nchw_nn_o_i = s[conv2d_nchw].split(conv2d_nchw_nn_o_i, factor=1)
conv2d_nchw_nn_o_o_o_i, conv2d_nchw_nn_o_o_i = s[conv2d_nchw].split(conv2d_nchw_nn_o_o_i, factor=1)
conv2d_nchw_nn_o_o_o_o, conv2d_nchw_nn_o_o_o_i = s[conv2d_nchw].split(conv2d_nchw_nn_o_o_o_i, factor=1)
conv2d_nchw_ff_o_i, conv2d_nchw_ff_i = s[conv2d_nchw].split(conv2d_nchw_ff, factor=1)
conv2d_nchw_ff_o_o_i, conv2d_nchw_ff_o_i = s[conv2d_nchw].split(conv2d_nchw_ff_o_i, factor=2)
conv2d_nchw_ff_o_o_o_i, conv2d_nchw_ff_o_o_i = s[conv2d_nchw].split(conv2d_nchw_ff_o_o_i, factor=64)
conv2d_nchw_ff_o_o_o_o, conv2d_nchw_ff_o_o_o_i = s[conv2d_nchw].split(conv2d_nchw_ff_o_o_o_i, factor=1)
conv2d_nchw_yy_o_i, conv2d_nchw_yy_i = s[conv2d_nchw].split(conv2d_nchw_yy, factor=1)
conv2d_nchw_yy_o_o_i, conv2d_nchw_yy_o_i = s[conv2d_nchw].split(conv2d_nchw_yy_o_i, factor=1)
conv2d_nchw_yy_o_o_o_i, conv2d_nchw_yy_o_o_i = s[conv2d_nchw].split(conv2d_nchw_yy_o_o_i, factor=1)
conv2d_nchw_yy_o_o_o_o, conv2d_nchw_yy_o_o_o_i = s[conv2d_nchw].split(conv2d_nchw_yy_o_o_o_i, factor=1)
conv2d_nchw_xx_o_i, conv2d_nchw_xx_i = s[conv2d_nchw].split(conv2d_nchw_xx, factor=1)
conv2d_nchw_xx_o_o_i, conv2d_nchw_xx_o_i = s[conv2d_nchw].split(conv2d_nchw_xx_o_i, factor=7)
conv2d_nchw_xx_o_o_o_i, conv2d_nchw_xx_o_o_i = s[conv2d_nchw].split(conv2d_nchw_xx_o_o_i, factor=1)
conv2d_nchw_xx_o_o_o_o, conv2d_nchw_xx_o_o_o_i = s[conv2d_nchw].split(conv2d_nchw_xx_o_o_o_i, factor=1)
conv2d_nchw_rc_o_i, conv2d_nchw_rc_i = s[conv2d_nchw].split(conv2d_nchw_rc, factor=2)
conv2d_nchw_rc_o_o, conv2d_nchw_rc_o_i = s[conv2d_nchw].split(conv2d_nchw_rc_o_i, factor=4)
conv2d_nchw_ry_o_i, conv2d_nchw_ry_i = s[conv2d_nchw].split(conv2d_nchw_ry, factor=1)
conv2d_nchw_ry_o_o, conv2d_nchw_ry_o_i = s[conv2d_nchw].split(conv2d_nchw_ry_o_i, factor=1)
conv2d_nchw_rx_o_i, conv2d_nchw_rx_i = s[conv2d_nchw].split(conv2d_nchw_rx, factor=1)
conv2d_nchw_rx_o_o, conv2d_nchw_rx_o_i = s[conv2d_nchw].split(conv2d_nchw_rx_o_i, factor=3)
s[conv2d_nchw].reorder(conv2d_nchw_nn_o_o_o_o, conv2d_nchw_ff_o_o_o_o, conv2d_nchw_yy_o_o_o_o, conv2d_nchw_xx_o_o_o_o, conv2d_nchw_nn_o_o_o_i, conv2d_nchw_ff_o_o_o_i, conv2d_nchw_yy_o_o_o_i, conv2d_nchw_xx_o_o_o_i, conv2d_nchw_nn_o_o_i, conv2d_nchw_ff_o_o_i, conv2d_nchw_yy_o_o_i, conv2d_nchw_xx_o_o_i, conv2d_nchw_rc_o_o, conv2d_nchw_ry_o_o, conv2d_nchw_rx_o_o, conv2d_nchw_rc_o_i, conv2d_nchw_ry_o_i, conv2d_nchw_rx_o_i, conv2d_nchw_nn_o_i, conv2d_nchw_ff_o_i, conv2d_nchw_yy_o_i, conv2d_nchw_xx_o_i, conv2d_nchw_rc_i, conv2d_nchw_ry_i, conv2d_nchw_rx_i, conv2d_nchw_nn_i, conv2d_nchw_ff_i, conv2d_nchw_yy_i, conv2d_nchw_xx_i)
compute_i0_o_i, compute_i0_i = s[compute].split(compute_i0, factor=1)
compute_i0_o_o_i, compute_i0_o_i = s[compute].split(compute_i0_o_i, factor=1)
compute_i0_o_o_o, compute_i0_o_o_i = s[compute].split(compute_i0_o_o_i, factor=1)
compute_i1_o_i, compute_i1_i = s[compute].split(compute_i1, factor=2)
compute_i1_o_o_i, compute_i1_o_i = s[compute].split(compute_i1_o_i, factor=64)
compute_i1_o_o_o, compute_i1_o_o_i = s[compute].split(compute_i1_o_o_i, factor=1)
compute_i2_o_i, compute_i2_i = s[compute].split(compute_i2, factor=1)
compute_i2_o_o_i, compute_i2_o_i = s[compute].split(compute_i2_o_i, factor=1)
compute_i2_o_o_o, compute_i2_o_o_i = s[compute].split(compute_i2_o_o_i, factor=1)
compute_i3_o_i, compute_i3_i = s[compute].split(compute_i3, factor=7)
compute_i3_o_o_i, compute_i3_o_i = s[compute].split(compute_i3_o_i, factor=1)
compute_i3_o_o_o, compute_i3_o_o_i = s[compute].split(compute_i3_o_o_i, factor=1)
s[compute].reorder(compute_i0_o_o_o, compute_i1_o_o_o, compute_i2_o_o_o, compute_i3_o_o_o, compute_i0_o_o_i, compute_i1_o_o_i, compute_i2_o_o_i, compute_i3_o_o_i, compute_i0_o_i, compute_i1_o_i, compute_i2_o_i, compute_i3_o_i, compute_i0_i, compute_i1_i, compute_i2_i, compute_i3_i)
s[conv2d_nchw].compute_at(s[compute], compute_i3_o_i)
kernel_shared = s.cache_read(kernel, "shared", [conv2d_nchw])
kernel_shared_ax0, kernel_shared_ax1, kernel_shared_ax2, kernel_shared_ax3 = tuple(kernel_shared.op.axis)
s[kernel_shared].compute_at(s[conv2d_nchw], conv2d_nchw_rx_o_o)
pad_temp_shared = s.cache_read(pad_temp, "shared", [conv2d_nchw])
pad_temp_shared_ax0, pad_temp_shared_ax1, pad_temp_shared_ax2, pad_temp_shared_ax3 = tuple(pad_temp_shared.op.axis)
s[pad_temp_shared].compute_at(s[conv2d_nchw], conv2d_nchw_rx_o_o)
s[pad_temp].compute_inline()
compute_i0_o_o_o_i1_o_o_o_fused_i2_o_o_o_fused_i3_o_o_o_fused = s[compute].fuse(compute_i0_o_o_o, compute_i1_o_o_o, compute_i2_o_o_o, compute_i3_o_o_o)
s[compute].bind(compute_i0_o_o_o_i1_o_o_o_fused_i2_o_o_o_fused_i3_o_o_o_fused, te.thread_axis("blockIdx.x"))
compute_i0_o_o_i_i1_o_o_i_fused_i2_o_o_i_fused_i3_o_o_i_fused = s[compute].fuse(compute_i0_o_o_i, compute_i1_o_o_i, compute_i2_o_o_i, compute_i3_o_o_i)
s[compute].bind(compute_i0_o_o_i_i1_o_o_i_fused_i2_o_o_i_fused_i3_o_o_i_fused, te.thread_axis("vthread"))
compute_i0_o_i_i1_o_i_fused_i2_o_i_fused_i3_o_i_fused = s[compute].fuse(compute_i0_o_i, compute_i1_o_i, compute_i2_o_i, compute_i3_o_i)
s[compute].bind(compute_i0_o_i_i1_o_i_fused_i2_o_i_fused_i3_o_i_fused, te.thread_axis("threadIdx.x"))
kernel_shared_ax0_ax1_fused_ax2_fused_ax3_fused = s[kernel_shared].fuse(kernel_shared_ax0, kernel_shared_ax1, kernel_shared_ax2, kernel_shared_ax3)
kernel_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o, kernel_shared_ax0_ax1_fused_ax2_fused_ax3_fused_i = s[kernel_shared].split(kernel_shared_ax0_ax1_fused_ax2_fused_ax3_fused, factor=1)
s[kernel_shared].vectorize(kernel_shared_ax0_ax1_fused_ax2_fused_ax3_fused_i)
kernel_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o_o, kernel_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o_i = s[kernel_shared].split(kernel_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o, factor=64)
s[kernel_shared].bind(kernel_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o_i, te.thread_axis("threadIdx.x"))
pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused = s[pad_temp_shared].fuse(pad_temp_shared_ax0, pad_temp_shared_ax1, pad_temp_shared_ax2, pad_temp_shared_ax3)
pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o, pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused_i = s[pad_temp_shared].split(pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused, factor=4)
s[pad_temp_shared].vectorize(pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused_i)
pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o_o, pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o_i = s[pad_temp_shared].split(pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o, factor=64)
s[pad_temp_shared].bind(pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o_i, te.thread_axis("threadIdx.x"))
s[conv2d_nchw].pragma(conv2d_nchw_nn_o_o_o_o, "auto_unroll_max_step", 512)
s[conv2d_nchw].pragma(conv2d_nchw_nn_o_o_o_o, "unroll_explicit", True)

CUDA source code:
#ifdef _WIN32
  using uint = unsigned int;
  using uchar = unsigned char;
  using ushort = unsigned short;
  using int64_t = long long;
  using uint64_t = unsigned long long;
#else
  #define uint unsigned int
  #define uchar unsigned char
  #define ushort unsigned short
  #define int64_t long long
  #define uint64_t unsigned long long
#endif
extern "C" __global__ void __launch_bounds__(64) default_function_kernel0(float* __restrict__ data, float* __restrict__ kernel, float* __restrict__ compute, float* __restrict__ bias) {
  float conv2d_nchw[14];
  __shared__ float pad_temp_shared[72];
  __shared__ float kernel_shared[3072];
  conv2d_nchw[0] = 0.000000e+00f;
  conv2d_nchw[1] = 0.000000e+00f;
  conv2d_nchw[2] = 0.000000e+00f;
  conv2d_nchw[3] = 0.000000e+00f;
  conv2d_nchw[4] = 0.000000e+00f;
  conv2d_nchw[5] = 0.000000e+00f;
  conv2d_nchw[6] = 0.000000e+00f;
  conv2d_nchw[7] = 0.000000e+00f;
  conv2d_nchw[8] = 0.000000e+00f;
  conv2d_nchw[9] = 0.000000e+00f;
  conv2d_nchw[10] = 0.000000e+00f;
  conv2d_nchw[11] = 0.000000e+00f;
  conv2d_nchw[12] = 0.000000e+00f;
  conv2d_nchw[13] = 0.000000e+00f;
  for (int rc_outer_outer = 0; rc_outer_outer ...

The following is an example of resuming the search. In this case, we have to create the search policy and the cost model ourselves and restore their state from the log file. In the example below, we restore the state and run 5 more trials.

def resume_search(task, log_file):
    print("Resume search:")
    cost_model = auto_scheduler.XGBModel()
    cost_model.update_from_file(log_file)
    search_policy = auto_scheduler.SketchPolicy(
        task, cost_model, init_search_callbacks=[auto_scheduler.PreloadMeasuredStates(log_file)]
    )
    measure_ctx = auto_scheduler.LocalRPCMeasureContext(min_repeat_ms=300)
    tune_option = auto_scheduler.TuningOptions(
        num_measure_trials=5,
        runner=measure_ctx.runner,
        measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
    )
    task.tune(tune_option, search_policy=search_policy)
    # Terminate the measurement process
    del measure_ctx

# We do not run the tuning here because it takes too long
# Uncomment the following line to run it yourself
# resume_search(task, log_file)
