AMCT目前主要支持基于重训练的通道稀疏模型压缩特性,稀疏示例请参见获取更多样例>resnet101,支持通道稀疏的层以及约束如下:
优化方式 |
支持的层类型 |
约束 |
---|---|---|
通道稀疏 |
torch.nn.Linear:全连接层 |
复用层(共用weight和bias参数)不支持稀疏。 |
torch.nn.Conv2d:卷积层 |
复用层(共用weight和bias参数)不支持稀疏。 depthwise只能被动稀疏(groups=in_channels),不能主动稀疏。 |
通道稀疏功能接口调用流程如图1所示。
import amct_pytorch as amct
推荐执行该步骤,请确保原始模型可以完成推理且精度正常;执行该步骤时,可以使用部分测试集,减少运行时间。
ori_model.load() # 测试模型 user_test_model(ori_model, test_data, test_iterations)
simple_cfg = './retrain.cfg' record_file = './tmp/record.txt' prune_retrain_model = amct.create_prune_retrain_model(model=ori_model, input_data=ori_model_input_data, config_defination=simple_cfg, record_file=record_file)
prune_retrain_model = amct.save_prune_retrain_model( model=pruned_retrain_model, save_path=save_path, input_data=input_data)
(由用户补充处理)基于ONNX Runtime的环境,使用通道稀疏后模型(prune_retrain_model)在测试集(test_data)上做推理,测试量化后仿真模型的精度。
使用稀疏后仿真模型精度与2中的原始精度做对比,可以观察通道稀疏对精度的影响。
prune_retrain_model = './results/user_model_fake_prune_model.onnx' user_do_inference_onnx(prune_retrain_model, test_data, test_iterations)
如果训练过程中断,需要从ckpt中恢复数据,继续训练,则调用流程为:
import amct_pytorch as amct
ori_model= user_create_model()
model = ori_model input_data = ori_model_input_data record_file = './tmp/record.txt' config_defination = './prune_cfg.cfg' save_pth_path = /your/path/to/save/tmp.pth model.load_state_dict(torch.load(state_dict_path)) prune_retrain_model = amct.restore_prune_retrain_model(model=ori_model, input_data=ori_model_input_data, record_file=record_file, config_defination='./prune_cfg.cfg', save_pth_path=/your/path/to/save/tmp.pth, 'state_dict')
prune_retrain_model = amct.save_prune_retrain_model( model=pruned_retrain_model, save_path=save_path, input_data=input_data)
(由用户补充处理)基于ONNX Runtime的环境,使用通道稀疏后模型(prune_retrain_model)在测试集(test_data)上做推理,测试量化后仿真模型的精度。
使用稀疏后仿真模型精度与2中的原始精度做对比,可以观察通道稀疏对精度的影响。
prune_retrain_model = './results/user_model_fake_prune_model.onnx' user_do_inference_onnx(prune_retrain_model, test_data, test_iterations)
如果稀疏后输出的模型为pth格式,则需要参考该章节,如果调用save_prune_retrain_model接口,则不需要。
由于输出的pth模型无法直接用于推理,需要用户自行将pth模型转成ONNX网络模型,或者调用save_prune_retrain_model接口保存为最终ONNX仿真模型以及部署模型,然后才能使用ATC工具进行模型转换。调用save_prune_retrain_model接口的调用示例如下:
prune_retrain_model = amct.10.6.3-save_prune_retrain_model( model=pruned_retrain_model, save_path=save_path, input_data=input_data)