量化感知训练会重新训练量化模型,从而减小模型大小,并且加快推理过程。当前支持对PyTorch框架的CNN类模型进行量化,并将量化后的模型保存为.onnx文件,量化过程中,需要用户自行提供模型与数据集,调用API接口完成模型的量化调优。
目前支持对包括但不限于表1 已验证模型列表中的模型进行模型量化感知训练。
from msmodelslim.pytorch.quant.qat_tools import qsin_qat, QatConfig, get_logger
quant_config = QatConfig(grad_scale=0.001) quant_logger = get_logger() model = qsin_qat(model, quant_config, quant_logger).to(device_calc) #根据实际情况配置待量化模型实例、量化配置和量化输出日志,注意需把模型按照原训练流程部署在NPU设备
bash ./test/train_full_1p.sh --data_path=/datasets/imagenet #请根据实际情况配置数据集路径
import argparse import os import torch import models.image_classification.resnet as nvmodels # 初始化模型 parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') parser.add_argument('-b', '--batch-size', default=1, type=int, metavar='N', help='onnx bs') parser.add_argument('--pretrained', default="./org_model_best.pth.tar", type=str, help='use pre-trained model') parser.add_argument('--quant_ckpt', default="./checkpoint_77.244_asym.pth.tar", type=str, help='use pre-trained model') args = parser.parse_args() model = nvmodels.build_resnet("resnet50", "classic", is_training=False) pretrained_dict = torch.load(args.pretrained, map_location='cpu')["state_dict"] model.load_state_dict(pretrained_dict, strict=False) #保存量化后的onnx模型 from msmodelslim.pytorch.quant.qat_tools import save_qsin_qat_model #根据实际情况配置导出后模型文件名(文件后缀需为.onnx)、输入的shape、伪量化模型权重和onnx的输入名称 save_onnx_name='./resnet50.onnx' dummy_input = torch.ones([args.batch_size, 3, 224, 224]).type(torch.float32) saved_ckpt = args.quant_ckpt input_names=['input1'] save_qsin_qat_model(model, save_onnx_name, dummy_input, saved_ckpt, input_names)
python3 quant_deploy.py