掘金 人工智能 05月23日
征程 6 J6E/M linear 双int16量化支持替代方案
index_new5.html
../../../zaker_core/zaker_tpl_static/wap/tpl_guoji1.html

 

本文针对地平线BPU上Linear算子量化时,input和weight都需要int16量化但硬件不支持的问题,提出了一种使用broadcast mul sum替代方案。该方案在float层面与linear等价,但在量化方式上有所区别,通常能带来精度正向优化。文章详细介绍了该替代方案的原理、实现方式,并提供了完整的代码示例,展示了如何在float训练完成后进行替换,并进行calib+qat。同时,也讨论了该方案的注意事项,以及mul输出scale设置问题。最后,文章还提供了三种不同量化情况下的完整示例,供读者参考。

💡当plugin精度debug工具定位到某个linear敏感时,且该linear的weight排在前面,并使用int8量化,可以考虑使用int16来量化该敏感linear的weight,以提升精度。

✔️针对Linear input和weight均需要int16量化,但地平线BPU硬件不支持的情况,可使用broadcast mul sum来替代验证,无需重训float。这种替换方案在float层面等价,量化方式上Linear weight是per channel量化,而替换后的weight作为mul输入时,是per tensor量化,通常精度会正向优化。

⚠️使用broadcast mul sum替换方案时,如果mul的输出绝大多数数值都在0附近,可能会导致MSE校准受异常值影响较大,进而影响精度。解决方案是为mul输出设置fixed scale为7/32767,以减小sigmoid或add+sigmoid的影响。

1.背景简介

当发现使用 plugin 精度 debug 工具定位到是某个 linear 敏感时,示例如下:

op_name                                sensitive_type    op_type                                                                          L1  quant_dtype    flops-------------------------------------  ---------------   -----------------------------  ----------------  -------------------------  -------  -------------  --------------model.layernorm.rsqrt                  activation        <class 'horizon_plugin_pytorch.nn.qat.segment_lut.SegmentLUT'>              6.52537  qint16         0(0%)model.linear2                          weight            <class 'horizon_plugin_pytorch.nn.qat.linear.Linear'>                       5.02445  qint8          3072000(0.00%)model.layernorm.var_mean.pre_mean      activation        <class 'horizon_plugin_pytorch.nn.qat.functional_modules.FloatFunctional'>  3.1683   qint16         0(0%)

可以发现,model.linear2 weight 排在了前面,且是 int8 量化。

接下来看下 baseline_statistic.txt 与 analysis_statistic.txt,其中有 model.linear2 的 input、weight、output 的数值分布范围,示例如下:

| Op Name                            | Mod Name       | Attr     | Min            | Max            | Mean           | Var        | Shape                       ||---------------------------------------------------------------------------------------------------------------------------------------------------------------| torch.nn.modules.linear.Linear     | model.linear2  | input    | 0.0000000      | 15.4210167     | 4.0793311      | 0.2532279  | torch.Size([2, 100, 256])   || torch.nn.modules.linear.Linear     | model.linear2  | weight   | -41.6590347    | 31.2311363     | -0.0053362     | 0.4427260  | torch.Size([60, 256])       || torch.nn.modules.linear.Linear     | model.linear2  | bias     | -0.4426649     | 0.3714900      | 0.0053294      | 0.0112585  | torch.Size([60])            || torch.nn.modules.linear.Linear     | model.linear2  | output   | -32.0065079    | 5.7881856      | 0.4558742      | 3.8736136  | torch.Size([2, 100, 60])    |

解决方案:使用 int16 来量化这个敏感 linear 的 weight。

如果必须要求 linear input weight output 都是 int16 量化,怎么办呢?

2.知识基础

在 征程 6E/M 上,地平线 BPU 对 linear 支持的情况如下:

本文发布时是这样的

可以看到:input 和 weight 不能同时为 int16。

3.Linear input weight both int16

对于 linear input 和 weight 均需要 int16 量化的情况,可使用 broadcast mul sum 来替代验证,无需重训 float。

异同简介:broadcast_mul_sum_replace_linear 在 float 层面可以等价替换 linear,但在量化方式上存在区别:Linear weight 是 per channel 量化,weight 作为 mul 输入时,是 per tensor 量化。一般情况下:weight int8 perchannel 变成 per tensor int16,精度是正向优化。

替换方案:在 float 训练完成后替换,然后进行 calib+qat。

class SmallModel(nn.Module):    def __init__(self, linear2_weight, linear2_bias):        super(SmallModel, self).__init__()        # 第一个 Linear: 输入 [2, 100, 256] -> 输出 [2, 100, 256]        self.linear1 = nn.Linear(256, 256)        self.layernorm = nn.LayerNorm(256)  # 对最后一维进行归一化        self.relu = nn.ReLU()        # 第二个 Linear: 输入 [2, 100, 256] -> 输出 [2, 100, 60]        # self.linear2 = nn.Linear(256, 60)        self.linear2_weight = linear2_weight        self.linear2_bias = linear2_bias        # 第三个 Linear: 输入 [2, 100, 60] -> 输出 [2, 100, 60]        self.linear3 = nn.Linear(60, 60)        self.quant = QuantStub()        self.dequant = DeQuantStub()        self.quant_linear2_weight = QuantStub()        self.quant_linear2_bias = QuantStub()        def forward(self, x):        x = self.quant(x)        linear2_weight = self.quant_linear2_weight(self.linear2_weight)        linear2_bias = self.quant_linear2_bias(self.linear2_bias)        # 第一个 Linear        x = self.linear1(x)  # [2, 100, 256]        x = self.layernorm(x)  # [2, 100, 256]        x = self.relu(x)  # [2, 100, 256]                # 第二个 Linear        # x = self.linear2(x)  # [2, 100, 60]        # ===================================        # 使用 broadcast mul + sum 替换linear        # ===================================        # 广播乘法:输入 [2, 100, 256] 与权重 [60, 256] 进行广播        broadcast_mul = x.reshape(2, 100, 1, 256) * linear2_weight.reshape(1, 1, 60, 256)  # [2, 100, 60, 256]        # 按最后一个维度求和:sum 操作模拟线性层的加权求和        sum_output = broadcast_mul.sum(dim=-1)  # [2, 100, 60]        # 加上偏置        x = sum_output + linear2_bias  # [2, 100, 60]                # 第三个 Linear        x = self.linear3(x)        x = self.dequant(x)        return x

broadcast mul sum 替换方案,均支持 int16。

注意事项:如果 mul 的输出 绝大多数 数值都在 0 附近 -> MSE 校准受异常值影响较大 -> 输出 scale 非常大 -> 0 附近的大量小数值被舍入成 0 -> sum 和发生巨大偏差。

影响范围:mul 后面跟着 sigmoid 或 add+sigmoid 时影响很大。

解决方案:mul 输出设置 fixed scale 为 7/32767,因为 sigmoid 并不需要太大的输入,而 mul 的输出分布需要小 scale。

4.全流程示例

从表中可以看到,在 linear 需要 int16 量化的场景,input/output int16 对应的 latency 最短,其次是 weight output int16 input int8,最差的是三者都需要 int16,针对这三种情况,下面分别提供完整的例子供参考。

信息描述

注意:非完全等价,仅作为参考

4.1 示例代码

import torchfrom horizon_plugin_pytorch import set_march, Marchset_march(March.NASH_M)from horizon_plugin_pytorch.quantization import prepare, set_fake_quantize, FakeQuantStatefrom horizon_plugin_pytorch.quantization import QuantStubfrom horizon_plugin_pytorch.quantization.hbdk4 import exportfrom horizon_plugin_pytorch.quantization.qconfig_template import calibration_8bit_weight_16bit_act_qconfig_setter, ModuleNameQconfigSetterfrom horizon_plugin_pytorch.quantization.qconfig import get_qconfig, MSEObserver, MinMaxObserverfrom horizon_plugin_pytorch.dtype import qint8, qint16from torch.quantization import DeQuantStubimport torch.nn as nnfrom horizon_plugin_pytorch.quantization import hbdk4 as hb4from hbdk4.compiler import convert, save, hbm_perf, visualize, compileimport torchimport torch.nn as nn# 定义网络结构class SmallModel(nn.Module):    def __init__(self, linear2_weight, linear2_bias):        super(SmallModel, self).__init__()        # 第一个 Linear: 输入 [2, 100, 256] -> 输出 [2, 100, 256]        self.linear1 = nn.Linear(256, 256)        self.layernorm = nn.LayerNorm(256)  # 对最后一维进行归一化        self.relu = nn.ReLU()        # 第二个 Linear: 输入 [2, 100, 256] -> 输出 [2, 100, 60]        # self.linear2 = nn.Linear(256, 60)        self.linear2_weight = linear2_weight        self.linear2_bias = linear2_bias        # 第三个 Linear: 输入 [2, 100, 60] -> 输出 [2, 100, 60]        self.linear3 = nn.Linear(60, 60)        self.quant = QuantStub()        self.dequant = DeQuantStub()        self.quant_linear2_weight = QuantStub()        self.quant_linear2_bias = QuantStub()        def forward(self, x):        x = self.quant(x)        linear2_weight = self.quant_linear2_weight(self.linear2_weight)        linear2_bias = self.quant_linear2_bias(self.linear2_bias)        # 第一个 Linear        x = self.linear1(x)  # [2, 100, 256]        x = self.layernorm(x)  # [2, 100, 256]        x = self.relu(x)  # [2, 100, 256]                # 第二个 Linear        # x = self.linear2(x)  # [2, 100, 60]        # ===================================        # 使用 broadcast mul + sum 替换linear        # ===================================        # 广播乘法:输入 [2, 100, 256] 与权重 [60, 256] 进行广播        broadcast_mul = x.reshape(2, 100, 1, 256) * linear2_weight.reshape(1, 1, 60, 256)  # [2, 100, 60, 256]        # 按最后一个维度求和:sum 操作模拟线性层的加权求和        sum_output = broadcast_mul.sum(dim=-1)  # [2, 100, 60]        # 加上偏置        x = sum_output + linear2_bias  # [2, 100, 60]                # 第三个 Linear        x = self.linear3(x)        x = self.dequant(x)        return xfloat_ckpt_path = "model_path/float-checkpoint.ckpt" float_state_dict = torch.load(float_ckpt_path)# 遍历 OrderedDict,查找包含 "linear2" 的键for key, value in float_state_dict.items():    # if "linear2" in key:    #     print(f"Key: {key}, Value: {value.shape}")    if key == "linear2.weight":        linear2_weight = value    if key == "linear2.bias":        linear2_bias = value# example_input = torch.randn(2, 100, 256)file_path = "random_data.pt"example_input = torch.load(file_path)model = SmallModel(linear2_weight, linear2_bias)missing_keys, unexpected_keys = model.load_state_dict(float_state_dict, strict=False)print("missing_keys & unexpected_keys:", missing_keys, '\n', unexpected_keys)# 前向传播output = model(example_input)print("float输出数据:", output)torch.save(output, "model_path/6_model_float_output.pt")print("输入形状:", example_input.shape)print("输出形状:", output.shape)# A global march indicating the target hardware version must be setted before prepare qat.set_march(March.NASH_M)calib_model = prepare(model.eval(), example_input,                      qconfig_setter=(                          calibration_8bit_weight_16bit_act_qconfig_setter,                          ),                      )calib_model.eval()set_fake_quantize(calib_model, FakeQuantState.CALIBRATION)calib_model(example_input)calib_model.eval()        set_fake_quantize(calib_model, FakeQuantState.VALIDATION)calib_out = calib_model(example_input)print("calib输出数据:", calib_out)qat_bc = export(calib_model, example_input)hb_quantized_model = convert(qat_bc, March.NASH_M)

4.2 比较替代方案的输出一致性

float输出数据: tensor([[[-0.3016,  0.1338, -0.5251,  ..., -0.0551, -0.2093, -0.0308],         [-0.1969, -0.0131, -0.3287,  ...,  0.3234, -0.0869, -0.0637],         [-0.3056,  0.1478, -0.2673,  ...,  0.2355, -0.3487,  0.0134],         ...,         [-0.3990, -0.0389, -0.1686,  ..., -0.0046, -0.4131,  0.0482],         [-0.1059,  0.2431, -0.1886,  ...,  0.0787, -0.3454,  0.0231],         [-0.2134, -0.1071, -0.0575,  ...,  0.3434, -0.1661,  0.2248]]],       grad_fn=<ViewBackward0>)       calib输出数据: tensor([[[-0.3038,  0.1370, -0.5269,  ..., -0.0571, -0.2111, -0.0296],         [-0.1975, -0.0111, -0.3280,  ...,  0.3215, -0.0884, -0.0637],         [-0.3052,  0.1488, -0.2677,  ...,  0.2348, -0.3479,  0.0132],         ...,         [-0.3988, -0.0393, -0.1662,  ..., -0.0055, -0.4117,  0.0484],         [-0.1058,  0.2442, -0.1890,  ...,  0.0780, -0.3447,  0.0240],         [-0.2142, -0.1061, -0.0587,  ...,  0.3422, -0.1657,  0.2255]]],       grad_fn=<ViewBackward0>)
float输出数据: tensor([[[-0.3016,  0.1338, -0.5251,  ..., -0.0551, -0.2093, -0.0308],         [-0.1969, -0.0131, -0.3287,  ...,  0.3234, -0.0869, -0.0637],         [-0.3056,  0.1478, -0.2673,  ...,  0.2355, -0.3487,  0.0134],         ...,         [-0.3990, -0.0389, -0.1686,  ..., -0.0046, -0.4131,  0.0482],         [-0.1059,  0.2431, -0.1886,  ...,  0.0787, -0.3454,  0.0231],         [-0.2134, -0.1071, -0.0575,  ...,  0.3434, -0.1661,  0.2248]]],       grad_fn=<ViewBackward0>)calib输出数据: tensor([[[-0.3038,  0.1370, -0.5269,  ..., -0.0571, -0.2111, -0.0296],         [-0.1975, -0.0111, -0.3280,  ...,  0.3215, -0.0884, -0.0637],         [-0.3051,  0.1487, -0.2678,  ...,  0.2349, -0.3478,  0.0132],         ...,         [-0.3988, -0.0392, -0.1662,  ..., -0.0055, -0.4117,  0.0484],         [-0.1058,  0.2442, -0.1890,  ...,  0.0780, -0.3447,  0.0240],         [-0.2142, -0.1061, -0.0586,  ...,  0.3423, -0.1657,  0.2255]]],       grad_fn=<ViewBackward0>)

Fish AI Reader

Fish AI Reader

AI辅助创作,多种专业模板,深度分析,高质量内容生成。从观点提取到深度思考,FishAI为您提供全方位的创作支持。新版本引入自定义参数,让您的创作更加个性化和精准。

FishAI

FishAI

鱼阅,AI 时代的下一个智能信息助手,助你摆脱信息焦虑

联系邮箱 441953276@qq.com

相关标签

Linear量化 int16量化 broadcast mul sum 地平线BPU 精度优化
相关文章