Apache TVM is a deep learning compiler framework for CPUs, GPUs, and a variety of machine learning accelerators. More TVM documentation in Chinese is available at → https://tvm.hyper.ai/
Authors: Lianmin Zheng, Chengfan Jia
This tutorial explains how to use the auto-scheduler for GPUs.
Unlike AutoTVM, which relies on manual templates to define the search space, the auto-scheduler does not require any templates. Users only need to write the compute declaration, without any schedule commands or templates; the auto-scheduler then automatically generates a large search space and finds a good schedule in it.
First, define the computation of the convolution layer. The function returns a list of input/output tensors, from which the auto-scheduler can get the whole computational graph.
# Imports used throughout this tutorial
import numpy as np
import tvm
from tvm import te, auto_scheduler, topi

@auto_scheduler.register_workload
def conv2d_layer(N, H, W, CO, CI, KH, KW, stride, padding):
    data = te.placeholder((N, CI, H, W), name="data")
    kernel = te.placeholder((CO, CI, KH, KW), name="kernel")
    bias = te.placeholder((1, CO, 1, 1), name="bias")
    conv = topi.nn.conv2d_nchw(data, kernel, stride, padding, dilation=1, out_dtype="float32")
    out = topi.nn.relu(conv + bias)
    return [data, kernel, bias, out]
Then create a search task for the last convolution layer in ResNet.
target = tvm.target.Target("cuda")

# Use the last layer in ResNet-50
N, H, W, CO, CI, KH, KW, strides, padding = 1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1)
task = auto_scheduler.SearchTask(
    func=conv2d_layer, args=(N, H, W, CO, CI, KH, KW, strides, padding), target=target
)

# Inspect the computational graph
print("Computational DAG:")
print(task.compute_dag)
Computational DAG:
data = PLACEHOLDER [1, 512, 7, 7]
pad_temp(i0, i1, i2, i3) = tir.if_then_else(((((i2 >= 1) && (i2 < 8)) && (i3 >= 1)) && (i3 < 8)), data[i0, i1, (i2 - 1), (i3 - 1)], 0f)
...
Next, set the parameters for the auto-scheduler. They mainly specify how measurements are performed during the search.
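The tuning options themselves are not shown in the text above, so the following is a minimal sketch assuming local measurement through an RPC measurement context. The log-file name "conv2d.json" and min_repeat_ms=300 match values mentioned later in this tutorial, while num_measure_trials=10 and verbose=2 are illustrative choices, not prescribed values:

log_file = "conv2d.json"
# Measure candidates on the local GPU through an RPC measurement context
measure_ctx = auto_scheduler.LocalRPCMeasureContext(min_repeat_ms=300)
tune_option = auto_scheduler.TuningOptions(
    num_measure_trials=10,  # a small number of trials so the example finishes quickly
    runner=measure_ctx.runner,
    measure_callbacks=[auto_scheduler.RecordToFile(log_file)],  # dump all records to the log file
    verbose=2,
)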
With all inputs prepared, the search can be started and the auto-scheduler left to do its work. After some measurement trials, the best schedule can be loaded from the log file and applied, as sketched below.
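A minimal sketch of this step, using the task, tune_option, and log_file from above (SearchTask.tune and SearchTask.apply_best are the standard auto-scheduler entry points):

# Run the auto-tuning (search)
task.tune(tune_option)
# Apply the best schedule found in the log file
sch, args = task.apply_best(log_file)
# Kill the measurement process
del measure_ctx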
Lower the schedule after auto-scheduling to see the IR. The auto-scheduler performs the optimizations correctly, including multi-level tiling, cooperative fetching, unrolling, and operator fusion.
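The lowered IR shown next can be printed with a standard lowering call; a one-line sketch using the sch and args obtained above (simple_mode=True only trims some boilerplate from the printed module):

print("Lowered TIR:")
print(tvm.lower(sch, args, simple_mode=True))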
Lowered TIR:
@main = primfn(data_1: handle, kernel_1: handle, bias_1: handle, compute_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {data: Buffer(data_2: Pointer(float32), float32, [25088], []),
             kernel: Buffer(kernel_2: Pointer(float32), float32, [2359296], []),
             bias: Buffer(bias_2: Pointer(float32), float32, [512], []),
             compute: Buffer(compute_2: Pointer(float32), float32, [25088], [])}
  buffer_map = {data_1: data, kernel_1: kernel, bias_1: bias, compute_1: compute}
  preflattened_buffer_map = {data_1: data_3: Buffer(data_2, float32, [1, 512, 7, 7], []), kernel_1: kernel_3: Buffer(kernel_2, float32, [512, 512, 3, 3], []), bias_1: bias_3: Buffer(bias_2, float32, [1, 512, 1, 1], []), compute_1: compute_3: Buffer(compute_2, float32, [1, 512, 7, 7], [])} {
  attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = 28;
  allocate(conv2d_nchw: Pointer(local float32), float32, [14]), storage_scope = local;
  allocate(pad_temp.shared: Pointer(shared float32), float32, [72]), storage_scope = shared;
  allocate(kernel.shared: Pointer(shared float32), float32, [3072]), storage_scope = shared;
  attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 64 {
    conv2d_nchw_1: Buffer(conv2d_nchw, float32, [14], [], scope="local", align=32)[0] = 0f32
    conv2d_nchw_1[1] = 0f32
    conv2d_nchw_1[2] = 0f32
    conv2d_nchw_1[3] = 0f32
    conv2d_nchw_1[4] = 0f32
    conv2d_nchw_1[5] = 0f32
    conv2d_nchw_1[6] = 0f32
    conv2d_nchw_1[7] = 0f32
    conv2d_nchw_1[8] = 0f32
    conv2d_nchw_1[9] = 0f32
    conv2d_nchw_1[10] = 0f32
    conv2d_nchw_1[11] = 0f32
    conv2d_nchw_1[12] = 0f32
    conv2d_nchw_1[13] = 0f32
    for (rc.outer.outer: int32, 0, 64) {
      for (ry.outer.outer: int32, 0, 3) {
        let cse_var_2: int32 = (rc.outer.outer*72)
        let cse_var_1: int32 = (ry.outer.outer*3)
         {
          attr [IterVar(threadIdx.x_1: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 64 {
            if @tir.likely((threadIdx.x_1 ...
Next, build the binary and check its correctness and performance.

from tvm.topi.testing import conv2d_nchw_python

func = tvm.build(sch, args, target)

# Check correctness
data_np = np.random.uniform(size=(N, CI, H, W)).astype(np.float32)
weight_np = np.random.uniform(size=(CO, CI, KH, KW)).astype(np.float32)
bias_np = np.random.uniform(size=(1, CO, 1, 1)).astype(np.float32)
conv_np = conv2d_nchw_python(data_np, weight_np, strides, padding)
out_np = np.maximum(conv_np + bias_np, 0.0)

dev = tvm.cuda()
data_tvm = tvm.nd.array(data_np, device=dev)
weight_tvm = tvm.nd.array(weight_np, device=dev)
bias_tvm = tvm.nd.array(bias_np, device=dev)
out_tvm = tvm.nd.empty(out_np.shape, device=dev)
func(data_tvm, weight_tvm, bias_tvm, out_tvm)

# Check the results
np.testing.assert_allclose(out_np, out_tvm.numpy(), rtol=1e-3)

# Evaluate the execution time
evaluator = func.time_evaluator(func.entry_name, dev, min_repeat_ms=500)
print(
    "Execution time of this operator: %.3f ms"
    % (np.median(evaluator(data_tvm, weight_tvm, bias_tvm, out_tvm).results) * 1000)
)
During the search, all measurement records are dumped into the record file "conv2d.json". They can be used to re-apply search results, resume the search, and perform other analyses.
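As an illustration of reading those records back, the sketch below counts them with auto_scheduler.load_records; the counting itself is only an example, not part of the original tutorial:

# Each record in the JSON log is a (MeasureInput, MeasureResult) pair
records = list(auto_scheduler.load_records(log_file))
print("Number of measurement records:", len(records))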
The example below loads the best schedule from the file and prints the equivalent Python schedule API and the CUDA source code, which can be used to debug and analyze the auto-scheduler's behavior.
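A minimal sketch of this step using SearchTask.print_best, which re-reads the log and returns either the equivalent schedule-API calls or the generated device code:

print("Equivalent python schedule:")
print(task.print_best(log_file, print_mode="schedule"))

print("CUDA source code:")
print(task.print_best(log_file, print_mode="cuda"))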
Equivalent python schedule:
pad_temp_i0, pad_temp_i1, pad_temp_i2, pad_temp_i3 = tuple(pad_temp.op.axis) + tuple(pad_temp.op.reduce_axis)
conv2d_nchw_nn, conv2d_nchw_ff, conv2d_nchw_yy, conv2d_nchw_xx, conv2d_nchw_rc, conv2d_nchw_ry, conv2d_nchw_rx = tuple(conv2d_nchw.op.axis) + tuple(conv2d_nchw.op.reduce_axis)
T_add_ax0, T_add_ax1, T_add_ax2, T_add_ax3 = tuple(T_add.op.axis) + tuple(T_add.op.reduce_axis)
compute_i0, compute_i1, compute_i2, compute_i3 = tuple(compute.op.axis) + tuple(compute.op.reduce_axis)
s[T_add].compute_inline()
conv2d_nchw_nn_o_i, conv2d_nchw_nn_i = s[conv2d_nchw].split(conv2d_nchw_nn, factor=1)
conv2d_nchw_nn_o_o_i, conv2d_nchw_nn_o_i = s[conv2d_nchw].split(conv2d_nchw_nn_o_i, factor=1)
conv2d_nchw_nn_o_o_o_i, conv2d_nchw_nn_o_o_i = s[conv2d_nchw].split(conv2d_nchw_nn_o_o_i, factor=1)
conv2d_nchw_nn_o_o_o_o, conv2d_nchw_nn_o_o_o_i = s[conv2d_nchw].split(conv2d_nchw_nn_o_o_o_i, factor=1)
conv2d_nchw_ff_o_i, conv2d_nchw_ff_i = s[conv2d_nchw].split(conv2d_nchw_ff, factor=1)
conv2d_nchw_ff_o_o_i, conv2d_nchw_ff_o_i = s[conv2d_nchw].split(conv2d_nchw_ff_o_i, factor=2)
conv2d_nchw_ff_o_o_o_i, conv2d_nchw_ff_o_o_i = s[conv2d_nchw].split(conv2d_nchw_ff_o_o_i, factor=64)
conv2d_nchw_ff_o_o_o_o, conv2d_nchw_ff_o_o_o_i = s[conv2d_nchw].split(conv2d_nchw_ff_o_o_o_i, factor=1)
conv2d_nchw_yy_o_i, conv2d_nchw_yy_i = s[conv2d_nchw].split(conv2d_nchw_yy, factor=1)
conv2d_nchw_yy_o_o_i, conv2d_nchw_yy_o_i = s[conv2d_nchw].split(conv2d_nchw_yy_o_i, factor=1)
conv2d_nchw_yy_o_o_o_i, conv2d_nchw_yy_o_o_i = s[conv2d_nchw].split(conv2d_nchw_yy_o_o_i, factor=1)
conv2d_nchw_yy_o_o_o_o, conv2d_nchw_yy_o_o_o_i = s[conv2d_nchw].split(conv2d_nchw_yy_o_o_o_i, factor=1)
conv2d_nchw_xx_o_i, conv2d_nchw_xx_i = s[conv2d_nchw].split(conv2d_nchw_xx, factor=1)
conv2d_nchw_xx_o_o_i, conv2d_nchw_xx_o_i = s[conv2d_nchw].split(conv2d_nchw_xx_o_i, factor=7)
conv2d_nchw_xx_o_o_o_i, conv2d_nchw_xx_o_o_i = s[conv2d_nchw].split(conv2d_nchw_xx_o_o_i, factor=1)
conv2d_nchw_xx_o_o_o_o, conv2d_nchw_xx_o_o_o_i = s[conv2d_nchw].split(conv2d_nchw_xx_o_o_o_i, factor=1)
conv2d_nchw_rc_o_i, conv2d_nchw_rc_i = s[conv2d_nchw].split(conv2d_nchw_rc, factor=2)
conv2d_nchw_rc_o_o, conv2d_nchw_rc_o_i = s[conv2d_nchw].split(conv2d_nchw_rc_o_i, factor=4)
conv2d_nchw_ry_o_i, conv2d_nchw_ry_i = s[conv2d_nchw].split(conv2d_nchw_ry, factor=1)
conv2d_nchw_ry_o_o, conv2d_nchw_ry_o_i = s[conv2d_nchw].split(conv2d_nchw_ry_o_i, factor=1)
conv2d_nchw_rx_o_i, conv2d_nchw_rx_i = s[conv2d_nchw].split(conv2d_nchw_rx, factor=1)
conv2d_nchw_rx_o_o, conv2d_nchw_rx_o_i = s[conv2d_nchw].split(conv2d_nchw_rx_o_i, factor=3)
s[conv2d_nchw].reorder(conv2d_nchw_nn_o_o_o_o, conv2d_nchw_ff_o_o_o_o, conv2d_nchw_yy_o_o_o_o, conv2d_nchw_xx_o_o_o_o, conv2d_nchw_nn_o_o_o_i, conv2d_nchw_ff_o_o_o_i, conv2d_nchw_yy_o_o_o_i, conv2d_nchw_xx_o_o_o_i, conv2d_nchw_nn_o_o_i, conv2d_nchw_ff_o_o_i, conv2d_nchw_yy_o_o_i, conv2d_nchw_xx_o_o_i, conv2d_nchw_rc_o_o, conv2d_nchw_ry_o_o, conv2d_nchw_rx_o_o, conv2d_nchw_rc_o_i, conv2d_nchw_ry_o_i, conv2d_nchw_rx_o_i, conv2d_nchw_nn_o_i, conv2d_nchw_ff_o_i, conv2d_nchw_yy_o_i, conv2d_nchw_xx_o_i, conv2d_nchw_rc_i, conv2d_nchw_ry_i, conv2d_nchw_rx_i, conv2d_nchw_nn_i, conv2d_nchw_ff_i, conv2d_nchw_yy_i, conv2d_nchw_xx_i)
compute_i0_o_i, compute_i0_i = s[compute].split(compute_i0, factor=1)
compute_i0_o_o_i, compute_i0_o_i = s[compute].split(compute_i0_o_i, factor=1)
compute_i0_o_o_o, compute_i0_o_o_i = s[compute].split(compute_i0_o_o_i, factor=1)
compute_i1_o_i, compute_i1_i = s[compute].split(compute_i1, factor=2)
compute_i1_o_o_i, compute_i1_o_i = s[compute].split(compute_i1_o_i, factor=64)
compute_i1_o_o_o, compute_i1_o_o_i = s[compute].split(compute_i1_o_o_i, factor=1)
compute_i2_o_i, compute_i2_i = s[compute].split(compute_i2, factor=1)
compute_i2_o_o_i, compute_i2_o_i = s[compute].split(compute_i2_o_i, factor=1)
compute_i2_o_o_o, compute_i2_o_o_i = s[compute].split(compute_i2_o_o_i, factor=1)
compute_i3_o_i, compute_i3_i = s[compute].split(compute_i3, factor=7)
compute_i3_o_o_i, compute_i3_o_i = s[compute].split(compute_i3_o_i, factor=1)
compute_i3_o_o_o, compute_i3_o_o_i = s[compute].split(compute_i3_o_o_i, factor=1)
s[compute].reorder(compute_i0_o_o_o, compute_i1_o_o_o, compute_i2_o_o_o, compute_i3_o_o_o, compute_i0_o_o_i, compute_i1_o_o_i, compute_i2_o_o_i, compute_i3_o_o_i, compute_i0_o_i, compute_i1_o_i, compute_i2_o_i, compute_i3_o_i, compute_i0_i, compute_i1_i, compute_i2_i, compute_i3_i)
s[conv2d_nchw].compute_at(s[compute], compute_i3_o_i)
kernel_shared = s.cache_read(kernel, "shared", [conv2d_nchw])
kernel_shared_ax0, kernel_shared_ax1, kernel_shared_ax2, kernel_shared_ax3 = tuple(kernel_shared.op.axis)
s[kernel_shared].compute_at(s[conv2d_nchw], conv2d_nchw_rx_o_o)
pad_temp_shared = s.cache_read(pad_temp, "shared", [conv2d_nchw])
pad_temp_shared_ax0, pad_temp_shared_ax1, pad_temp_shared_ax2, pad_temp_shared_ax3 = tuple(pad_temp_shared.op.axis)
s[pad_temp_shared].compute_at(s[conv2d_nchw], conv2d_nchw_rx_o_o)
s[pad_temp].compute_inline()
compute_i0_o_o_o_i1_o_o_o_fused_i2_o_o_o_fused_i3_o_o_o_fused = s[compute].fuse(compute_i0_o_o_o, compute_i1_o_o_o, compute_i2_o_o_o, compute_i3_o_o_o)
s[compute].bind(compute_i0_o_o_o_i1_o_o_o_fused_i2_o_o_o_fused_i3_o_o_o_fused, te.thread_axis("blockIdx.x"))
compute_i0_o_o_i_i1_o_o_i_fused_i2_o_o_i_fused_i3_o_o_i_fused = s[compute].fuse(compute_i0_o_o_i, compute_i1_o_o_i, compute_i2_o_o_i, compute_i3_o_o_i)
s[compute].bind(compute_i0_o_o_i_i1_o_o_i_fused_i2_o_o_i_fused_i3_o_o_i_fused, te.thread_axis("vthread"))
compute_i0_o_i_i1_o_i_fused_i2_o_i_fused_i3_o_i_fused = s[compute].fuse(compute_i0_o_i, compute_i1_o_i, compute_i2_o_i, compute_i3_o_i)
s[compute].bind(compute_i0_o_i_i1_o_i_fused_i2_o_i_fused_i3_o_i_fused, te.thread_axis("threadIdx.x"))
kernel_shared_ax0_ax1_fused_ax2_fused_ax3_fused = s[kernel_shared].fuse(kernel_shared_ax0, kernel_shared_ax1, kernel_shared_ax2, kernel_shared_ax3)
kernel_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o, kernel_shared_ax0_ax1_fused_ax2_fused_ax3_fused_i = s[kernel_shared].split(kernel_shared_ax0_ax1_fused_ax2_fused_ax3_fused, factor=1)
s[kernel_shared].vectorize(kernel_shared_ax0_ax1_fused_ax2_fused_ax3_fused_i)
kernel_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o_o, kernel_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o_i = s[kernel_shared].split(kernel_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o, factor=64)
s[kernel_shared].bind(kernel_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o_i, te.thread_axis("threadIdx.x"))
pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused = s[pad_temp_shared].fuse(pad_temp_shared_ax0, pad_temp_shared_ax1, pad_temp_shared_ax2, pad_temp_shared_ax3)
pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o, pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused_i = s[pad_temp_shared].split(pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused, factor=4)
s[pad_temp_shared].vectorize(pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused_i)
pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o_o, pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o_i = s[pad_temp_shared].split(pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o, factor=64)
s[pad_temp_shared].bind(pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o_i, te.thread_axis("threadIdx.x"))
s[conv2d_nchw].pragma(conv2d_nchw_nn_o_o_o_o, "auto_unroll_max_step", 512)
s[conv2d_nchw].pragma(conv2d_nchw_nn_o_o_o_o, "unroll_explicit", True)

CUDA source code:
#ifdef _WIN32
  using uint = unsigned int;
  using uchar = unsigned char;
  using ushort = unsigned short;
  using int64_t = long long;
  using uint64_t = unsigned long long;
#else
  #define uint unsigned int
  #define uchar unsigned char
  #define ushort unsigned short
  #define int64_t long long
  #define uint64_t unsigned long long
#endif
extern "C" __global__ void __launch_bounds__(64) default_function_kernel0(float* __restrict__ data, float* __restrict__ kernel, float* __restrict__ compute, float* __restrict__ bias) {
  float conv2d_nchw[14];
  __shared__ float pad_temp_shared[72];
  __shared__ float kernel_shared[3072];
  conv2d_nchw[0] = 0.000000e+00f;
  conv2d_nchw[1] = 0.000000e+00f;
  conv2d_nchw[2] = 0.000000e+00f;
  conv2d_nchw[3] = 0.000000e+00f;
  conv2d_nchw[4] = 0.000000e+00f;
  conv2d_nchw[5] = 0.000000e+00f;
  conv2d_nchw[6] = 0.000000e+00f;
  conv2d_nchw[7] = 0.000000e+00f;
  conv2d_nchw[8] = 0.000000e+00f;
  conv2d_nchw[9] = 0.000000e+00f;
  conv2d_nchw[10] = 0.000000e+00f;
  conv2d_nchw[11] = 0.000000e+00f;
  conv2d_nchw[12] = 0.000000e+00f;
  conv2d_nchw[13] = 0.000000e+00f;
  for (int rc_outer_outer = 0; rc_outer_outer ...
The following is an example of resuming the search. In this case, we need to create the search policy and cost model ourselves and restore their states from the log file. In the example below, we restore the states and do 5 more trials.
def resume_search(task, log_file):
    print("Resume search:")
    cost_model = auto_scheduler.XGBModel()
    cost_model.update_from_file(log_file)
    search_policy = auto_scheduler.SketchPolicy(
        task, cost_model, init_search_callbacks=[auto_scheduler.PreloadMeasuredStates(log_file)]
    )
    measure_ctx = auto_scheduler.LocalRPCMeasureContext(min_repeat_ms=300)
    tune_option = auto_scheduler.TuningOptions(
        num_measure_trials=5,
        runner=measure_ctx.runner,
        measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
    )
    task.tune(tune_option, search_policy=search_policy)
    # Kill the measurement process
    del measure_ctx

# We do not run the tuning again on the server because it takes too long.
# Uncomment the following line to run it by yourself.
# resume_search(task, log_file)