单算子模式的量化感知训练
功能介绍
参考量化感知训练章节进行基础量化时,量化的内部处理逻辑需要将原始模型转换成ONNX模型,并在ONNX模型基础上进行图的修改操作,此时若模型中存在Pytorch自定义算子时,可能存在无法导出生成ONNX模型,从而导致量化失败的问题。
单算子模式的量化感知训练功能,提供由Pytorch原生算子转换生成的自定义QAT算子,基于该算子进行量化因子的重训练,量化因子作为算子参数保存在QAT单算子中,无需导出ONNX模型,可以避免上述量化感知训练方案中的算子导出异常问题。训练完成后,通过torch.onnx.export机制,建立QAT算子与ONNX原生算子的映射关系,将Pytorch模型中QAT算子计算获得的参数传递给ONNX原生量化算子,完成模型导出。简易示意图如图1所示。
实现原理如下图所示(以torch.nn.Conv2d算子为例进行说明):
在torch原生算子torch.nn.Conv2d中插入ARQ权重量化,ULQ数据量化算子,进行量化因子的重训练,将量化因子作为算子参数保存在Conv2dQAT单算子中,最终调用torch.onnx.export接口,将QAT自定义算子通过映射关系导出成ONNX原生量化算子QuantizeLinear与DequantizeLinear,带有此两个量化算子的模型又称QDQ(Quantize and DeQuantize,简称QDQ)ONNX模型,QDQ格式模型具体介绍请参见Link。
QDQ ONNX模型不能直接通过ATC工具转成适配昇腾AI处理器的离线模型,需要通过AMCT(ONNX)中的QAT模型适配CANN模型特性,将该模型适配成CANN模型后,然后才能使用ATC工具进行下一步的转换。
该功能支持Training from scratch和Fine-tune两种使用方法:
- Training from scratch:使用QAT算子直接构图,从零开始训练。您可以在模型构建脚本中调用单算子模式提供的直接构造接口构造QAT算子,使用该算子进行模型构建。
- Fine-tune:在已有网络基础上,对待量化算子进行替换,相比于Training from scratch更为常用。如果您已经完成模型构建,您可以调用单算子模式提供的基于原生算子构造接口,进行QAT算子构造;之后可以参考下文样例中的算子替换方案,对网络模型中的待量化算子进行替换。
QAT算子规格如下表所示,调用示例请参见样例列表。
待量化算子类型 |
替换后算子类型 |
限制 |
备注 |
---|---|---|---|
torch.nn.Conv2d |
Conv2dQAT |
|
复用层(共用weight和bias参数)不支持量化。 |
torch.nn.ConvTranspose2d |
ConvTranspose2dQAT |
|
|
torch.nn.Conv3d |
Conv3dQAT |
|
|
torch.nn.Linear |
LinearQAT |
不支持channel wise |
|
torch.nn.GRU |
GRUQAT |
|
- |
样例参考
Training from scratch方式代码样例:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 | import torch import torch.nn as nn import amct_pytorch as amct from amct_onnx.convert_model import convert_qat_model from amct_pytorch.nn.module.quantization.conv2d import Conv2dQAT from amct_pytorch.nn.module.quantization.linear import LinearQAT # QAT算子量化配置项 config = { "retrain_data_config": { "dst_type": "INT8", "batch_num": 10, "fixed_min": False, "clip_min": None, "clip_max": None }, "retrain_weight_config": { "dst_type": "INT8", "weights_retrain_algo": "arq_retrain", "channel_wise": False } } # 使用QAT单算子构造Lenet网络 net = torch.nn.Sequential( Conv2dQAT(1, 6, kernel_size=5, padding=2, config=config), nn.Sigmoid(), nn.AvgPool2d(kernel_size=2, stride=2), Conv2dQAT(6, 16, kernel_size=5, config=config), nn.Sigmoid(), nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(), LinearQAT(16 * 5 * 5, 120, config=config), nn.Sigmoid(), LinearQAT(120, 84, config=config), nn.Sigmoid(), LinearQAT(84, 10, config=config) ) # 训练 train(net, train_data, test_data) # 导出中间模型 torch.onnx.export(model, data, 'inter_model.onnx') # 导出fake quant模型与deploy模型 # 生成lenet_fake_quant_model.onnx,可在 ONNX 执行框架 ONNX Runtime 进行精度仿真的模型。 # 生成lenet_deploy_model.onnx,可在AI 处理器部署的模型文件。 convert_qat_model('inter_model.onnx', './outputs/lenet') # 使用fake quant模型进行精度仿真 validate_onnx('./outputs/lenet_fake_quant_model.onnx', val_data) |
Fine-tune方式代码样例:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 | import torch from torchvision.models.resnet import resnet101 import amct_pytorch as amct from amct_onnx.convert_model import convert_qat_model from amct_pytorch.nn.module.quantization.conv2d import Conv2dQAT model = resnet101(pretrained=True) # QAT算子量化配置项 config = { "retrain_data_config": { "dst_type": "INT8", "batch_num": 10, "fixed_min": False, "clip_min": None, "clip_max": None }, "retrain_weight_config": { "dst_type": "INT8", "weights_retrain_algo": "arq_retrain", "channel_wise": True } } def _set_module(model, submodule_key, module): # 将模型中原生算子替换为qat算子 tokens = submodule_key.split('.') sub_tokens = tokens[:-1] cur_mod = model for s in sub_tokens: cur_mod = getattr(cur_mod, s) setattr(cur_mod, tokens[-1], module) for name, module in model.named_modules(): # 遍历原图中各节点,将类型为Conv2d的torch原生算子替换为自定义QAT单算子 if isinstance(module, torch.nn.Conv2d): qat_module = Conv2dQAT.from_float( module, config=config) _set_module(model, name, qat_module) # 训练流程 train_and_val(model, train_data, test_data) # 导出中间模型 torch.onnx.export(model, data, 'inter_model.onnx') # 导出fake quant模型与deploy模型 # 生成resnet101_fake_quant_model.onnx,可在 ONNX 执行框架 ONNX Runtime 进行精度仿真的模型。 # 生成resnet101_deploy_model.onnx,可在AI 处理器部署的模型文件。 convert_qat_model('inter_model.onnx', './outputs/resnet101') # 使用fake quant模型进行精度仿真 validate_onnx('./outputs/resnet101_fake_quant_model.onnx', val_data) |