掘金 人工智能 06月01日 18:53
征程 6EM 常见 QConfig 配置解读与示例
index_new5.html
../../../zaker_core/zaker_tpl_static/wap/tpl_guoji1.html

 

本文深入解析了J6EM上量化感知训练(QAT)中QConfig的配置,旨在帮助开发者理解并灵活运用QConfig。文章首先介绍了基础模板,包括default和int16模板,阐述了它们在校准和QAT训练中的作用和区别。接着,详细讲解了敏感度模板,如何结合精度debug工具的敏感度结果进行混合精度调优。最后,介绍了自定义模板ModuleNameQconfigSetter,用于满足fixed scale、linear weight int16等特殊需求,并提供了可运行的示例,展示了如何在网络中配置不同算子的量化类型,以及如何为add算子配置固定scale。

💡**基础模板:** 介绍了`default`和`int16`两种基础模板,它们分别适用于不同的量化策略。`default`模板会将grid sample算子的grid输入到第一个gemm类算子或QuantStub之间的算子设置为int16,其余设置为int8;`int16`模板则直接将所有算子设置为int16。

🌡️**敏感度模板:** 敏感度模板允许根据量化敏感度结果,将敏感度最高的topk个算子设置为int16,从而实现混合精度调优。该模板需要结合精度debug工具产生的敏感度结果使用,可以针对模型的多个输出设置多个敏感度模板。

🛠️**自定义模板:** 通过`ModuleNameQconfigSetter`自定义模板,可以针对特定模块配置自定义的qconfig,满足诸如设置fixed scale、配置linear weight int16等特殊需求。自定义模板可以与固定模板、敏感度模板搭配使用,以实现更精细的量化控制。

🧪**可运行示例:** 文章提供了两个可运行的示例,分别展示了如何将网络中linear2的weight配置为int16量化,以及如何配置add单算子的输入和输出均使用固定scale。这些示例为开发者提供了实际操作的参考。

⚙️**QConfig配置:** QConfig控制了模型所有节点的量化类型,例如是采用 int8 还是 int16 量化,是固定校准阶段的 scale 去 qat 还是不固定 scale 去 qat。提供的模板可分为三类:基础模板、敏感度模板、自定义模板。

一、引言

在工具链用户手册《量化感知训练(QAT)-开发指南-QConfig 详解》章节专门介绍了在 J6EM 上 qconfig 是怎么回事,从经历看,大家可能会存在看了依旧不懂,或懂了不知道怎么配置的情况,特别是一些 OE 包中示例没有的配置,例如固定某节点 scale、配置 linear weight int16 等操作。

qconfig 控制了模型所有节点的量化类型,例如是采用 int8 还是 int16 量化,是固定校准阶段的 scale 去 qat 还是不固定 scale 去 qat。

提供的模板可分为三类:基础模板、敏感度模板、自定义模板。本文将常见配置通过示例方式进行呈现。

二、基础模板

基础模板中 calibration / qat / qat_fixed_act_scale 区别在于使用的 observer 类型和 scale 更新逻辑,分别用于校准,不固定 activation scaleqat 训练,固定 activation scale qat 训练。

default 模板 ( default_calibration_qconfig_setter / default_qat_qconfig_setter / default_qat_fixed_act_qconfig_setter ) 会做三件事:

int16 模板 ( qat_8bit_weight_16bit_act_qconfig_setter / qat_8bit_weight_16bit_fixed_act_qconfig_setter / calibration_8bit_weight_16bit_act_qconfig_setter ) 会做两件事:

from horizon_plugin_pytorch.quantization.qconfig_template import (    default_calibration_qconfig_setter,    default_qat_qconfig_setter,    default_qat_fixed_act_qconfig_setter,    qat_8bit_weight_16bit_act_qconfig_setter,    qat_8bit_weight_16bit_fixed_act_qconfig_setter,    calibration_8bit_weight_16bit_act_qconfig_setter,)qat_or_calib_model = prepare(    float_model,    example_inputs=example_inputs,  # 用来感知图结构    qconfig_setter=(        default_qat_qconfig_setter,    # 根据需要配置setter模板    ),)

三、敏感度模板

敏感度模板有三个:

sensitive_op_calibration_8bit_weight_16bit_act_qconfig_settersensitive_op_qat_8bit_weight_16bit_act_qconfig_settersensitive_op_qat_8bit_weight_16bit_fixed_act_qconfig_setter

三者的区别和基础模板中三者的区别类似,也是分别用于校准,不固定 activation scale qat 训练,固定 activation scale qat 训练。

敏感度模板的第一个输入是精度 debug 工具产生的敏感度结果,第二个参数可以指定 ratio 或 topk,敏感度模板会根据配置,将量化敏感度最高的 topk 个算子设置为 int16。搭配固定模板,可以实现混合精度调优。

若模型有多个输出,每个输出都会产生一个敏感度表,您可以设置多个敏感度模版。示例如下:

from horizon_plugin_pytorch.quantization.qconfig_template import (    default_calibration_qconfig_setter,    sensitive_op_qat_8bit_weight_16bit_act_qconfig_setter,    sensitive_op_qat_8bit_weight_16bit_fixed_act_qconfig_setter,    sensitive_op_calibration_8bit_weight_16bit_act_qconfig_setter,)# 这两个pt文件是通过debug工具得到的table1 = torch.load("output_0-0_L1_sensitive_ops.pt")table2 = torch.load("output_0-1_L1_sensitive_ops.pt")calibration_model = prepare(    float_model,    example_inputs=example_input,    qconfig_setter=(        sensitive_op_calibration_8bit_weight_16bit_act_qconfig_setter(table1, ratio=0.2),        sensitive_op_calibration_8bit_weight_16bit_act_qconfig_setter(table2, ratio=0.2),        default_calibration_qconfig_setter,    ),)

四、自定义模板

自定义模板为 ModuleNameQconfigSetter,需要传入模块名和对应自定义的 qconfig,一般用于设置 fixed scale、配置 linear weight int16 等特殊需求,可以和固定模板,敏感度模板搭配使用。示例如下:

from horizon_plugin_pytorch.quantization.qconfig_template import (    calibration_8bit_weight_16bit_act_qconfig_setter,    ModuleNameQconfigSetter,)from horizon_plugin_pytorch.quantization.qconfig import (    get_qconfig,    MSEObserver,    MinMaxObserver,    FixedScaleObserver,    QConfig,)from horizon_plugin_pytorch.quantization.fake_quantize import FakeQuantize# 手动设置某个算子的输出scaleop_name_output_fix_scale_qconfig = QConfig(    output=FakeQuantize.with_args(        observer=FixedScaleObserver,        dtype=qint16,        scale=0.0625,    ))# 设置某个算子weight与输出activation的量化类型# 校准时用MSEObserver,qat时用MinMaxObserver# 没有weight的算子,配置了weight_dtype也不会起作用calib_weight_act_both_int16_qconfig = get_qconfig(    observer=MSEObserver,    weight_dtype=qint16,    out_dtype=qint16,)calib_weight_act_both_int8_qconfig = get_qconfig(    observer=MSEObserver,    weight_dtype=qint8,    out_dtype=qint8,)qat_weight_act_both_int16_qconfig = get_qconfig(    observer=MinMaxObserver,    weight_dtype=qint16,    out_dtype=qint16,    fix_scale=True,    # 是否固定scale)

放在一块简单示例如下:

from horizon_plugin_pytorch.quantization.qconfig_template import (    default_qat_qconfig_setter,    sensitive_op_qat_8bit_weight_16bit_fixed_act_qconfig_setter,    ModuleNameQconfigSetter,)table = torch.load("output_0-0_dataindex_1_sensitive_ops.pt")# 自动替换生成的算子只能通过 ModuleNameQconfigSetter 配置自定义 qconfig。module_name_to_qconfig = {    "_generated_add_0": op_name_output_fix_scale_qconfig ,}qat_model = prepare(    float_model,    example_inputs=example_input,    qconfig_setter=(        ModuleNameQconfigSetter(module_name_to_qconfig),        sensitive_op_qat_8bit_weight_16bit_fixed_act_qconfig_setter(table, ratio=0.2),        default_qat_qconfig_setter,    ),)

五、可运行的示例

import torchfrom horizon_plugin_pytorch import set_march, Marchset_march(March.NASH_M)from horizon_plugin_pytorch.quantization import prepare, set_fake_quantize, FakeQuantStatefrom horizon_plugin_pytorch.quantization import QuantStubfrom horizon_plugin_pytorch.quantization.hbdk4 import exportfrom horizon_plugin_pytorch.quantization.qconfig_template import calibration_8bit_weight_16bit_act_qconfig_setter, ModuleNameQconfigSetterfrom horizon_plugin_pytorch.quantization.qconfig import get_qconfig, MSEObserver, MinMaxObserverfrom horizon_plugin_pytorch.dtype import qint8, qint16from torch.quantization import DeQuantStubimport torch.nn as nn# 定义网络结构class SmallModel(nn.Module):    def __init__(self):        super(SmallModel, self).__init__()        # 第一个 Linear: 输入 [2, 100, 256] -> 输出 [2, 100, 256]        self.linear1 = nn.Linear(256, 256)        self.layernorm = nn.LayerNorm(256)  # 对最后一维进行归一化        self.relu = nn.ReLU()        # 第二个 Linear: 输入 [2, 100, 256] -> 输出 [2, 100, 60]        self.linear2 = nn.Linear(256, 60)        # 第三个 Linear: 输入 [2, 100, 60] -> 输出 [2, 100, 60]        self.linear3 = nn.Linear(60, 60)        self.quant = QuantStub()        self.dequant = DeQuantStub()    def forward(self, x):        x = self.quant(x)        # 第一个 Linear        x = self.linear1(x)  # [2, 100, 256]        x = self.layernorm(x)  # [2, 100, 256]        x = self.relu(x)  # [2, 100, 256]        # 第二个 Linear        x = self.linear2(x)  # [2, 100, 60]        # 第三个 Linear        x = self.linear3(x)        x = self.dequant(x)        return xexample_input = torch.randn(2, 100, 256)model = SmallModel()# 前向传播output = model(example_input)print("输出形状:", output.shape)# A global march indicating the target hardware version must be setted before prepare qat.set_march(March.NASH_M)calib_weight_act_both_int16_qconfig = get_qconfig(    observer=MSEObserver,    weight_dtype=qint16,    out_dtype=qint16,)# layernorm没有weight,配置了weight_dtype也不会起作用calib_weight_act_both_int8_qconfig = get_qconfig(    observer=MSEObserver,    weight_dtype=qint8,    out_dtype=qint8,)qat_weight_act_both_int16_qconfig = get_qconfig(    observer=MinMaxObserver,    weight_dtype=qint16,    out_dtype=qint16,    fix_scale=True,)# 节点名称,可以从model_check_result.txt中获取,也可以从敏感度文件中获取module_name_to_qconfig = {    "layernorm": calib_weight_act_both_int8_qconfig,    "linear2": calib_weight_act_both_int16_qconfig,   }calib_model = prepare(model.eval(), example_input,                      qconfig_setter=(                          ModuleNameQconfigSetter(module_name_to_qconfig),                          calibration_8bit_weight_16bit_act_qconfig_setter,                          ),                      )calib_model.eval()set_fake_quantize(calib_model, FakeQuantState.CALIBRATION)calib_model(example_input)calib_model.eval()                            set_fake_quantize(calib_model, FakeQuantState.VALIDATION)calib_out = calib_model(example_input)qat_bc = export(calib_model, example_input)
import torchfrom horizon_plugin_pytorch import set_march, Marchset_march(March.NASH_E)from horizon_plugin_pytorch.quantization import prepare, set_fake_quantize, FakeQuantStatefrom horizon_plugin_pytorch.quantization import QuantStubfrom horizon_plugin_pytorch.quantization.hbdk4 import exportfrom horizon_plugin_pytorch.quantization.qconfig_template import calibration_8bit_weight_16bit_act_qconfig_setter, ModuleNameQconfigSetterfrom horizon_plugin_pytorch.quantization.qconfig import get_qconfig, MSEObserver, MinMaxObserver, FixedScaleObserver, QConfigfrom horizon_plugin_pytorch.quantization.fake_quantize import FakeQuantizefrom horizon_plugin_pytorch.dtype import qint8, qint16from torch.quantization import DeQuantStubimport torch.nn as nnclass AddNet(nn.Module):    def __init__(self):        super(AddNet, self).__init__()        self.quant_x = QuantStub()        self.quant_y = QuantStub()        self.dequant = DeQuantStub()    def forward(self, x, y):        x = self.quant_x(x)        y = self.quant_y(y)        z = torch.add(x, y)        z = self.dequant(z)        return z# 创建模型model = AddNet()# 生成两个相同形状的输入张量torch.manual_seed(42)x = torch.randn(1, 1, 2, 6)y = torch.randn(1, 2, 2, 6)example_input = (x,y)# 前向传播output = model(example_input[0], example_input[1])print("float输出数据:", output)print("输入形状:", example_input[0].shape)print("输出形状:", output.shape)# A global march indicating the target hardware version must be setted before prepare qat.set_march(March.NASH_E)add_input_fix_scale_qconfig = QConfig(    output=FakeQuantize.with_args(        observer=FixedScaleObserver,        dtype=qint16,        scale=0.03125,    ))add_output_fix_scale_qconfig = QConfig(    output=FakeQuantize.with_args(        observer=FixedScaleObserver,        dtype=qint16,        scale=0.0625,    ))# 节点名称,可以从model_check_result.txt中获取,也可以从敏感度文件中获取module_name_to_qconfig = {    "quant_x": add_input_fix_scale_qconfig,    "quant_y": add_input_fix_scale_qconfig,    "_generated_add_0": add_output_fix_scale_qconfig,}calib_model = prepare(model.eval(), example_input,                      qconfig_setter=(                          ModuleNameQconfigSetter(module_name_to_qconfig),                          calibration_8bit_weight_16bit_act_qconfig_setter,                          ),                      )calib_model.eval()set_fake_quantize(calib_model, FakeQuantState.CALIBRATION)calib_model(example_input[0], example_input[1])calib_model.eval()                            set_fake_quantize(calib_model, FakeQuantState.VALIDATION)calib_out = calib_model(example_input[0], example_input[1])print("calib输出数据:", calib_out)qat_bc = export(calib_model, example_input)

六、冻结部分网络结构 qat 的配置

补充常见冻结网络结构,去进行 qat 的做法

from horizon_plugin_pytorch.quantization import (    QuantStub,    prepare,    set_fake_quantize,    FakeQuantState,)#prepare QAT模型qat_model = prepare(    model,    example_inputs=xxx,    qconfig_setter=(        xxx,    ))#加载calib权重qat_model.load_state_dict(torch.load("calib-checkpoint.ckpt"))#QAT训练qat_model.train()#固定backbone部分的权重,requires_grad不影响drop bn的行为,需要与eval联合用for param in qat_model.backbone.parameters():    param.requires_grad = False#固定backbone部分的scale,eval只影响drop bn的行为,如果发生了backward仍然会改变权重,需要与requires_grad联合使用qat_model.backbone.eval()set_fake_quantize(qat_model.backbone, FakeQuantState.VALIDATION)#配置head的FakeQuant为QAT状态set_fake_quantize(qat_model.head, FakeQuantState.QAT)

Fish AI Reader

Fish AI Reader

AI辅助创作,多种专业模板,深度分析,高质量内容生成。从观点提取到深度思考,FishAI为您提供全方位的创作支持。新版本引入自定义参数,让您的创作更加个性化和精准。

FishAI

FishAI

鱼阅,AI 时代的下一个智能信息助手,助你摆脱信息焦虑

联系邮箱 441953276@qq.com

相关标签

量化感知训练 QConfig J6EM 模型量化 PyTorch
相关文章