逐层蒸馏示例请参见获取更多样例。
调用流程如图1所示。
1
|
import amct_pytorch as amct |
建议使用原始待量化的模型和测试集,在PyTorch环境下推理,验证环境、推理脚本是否正常。
推荐执行该步骤,请确保原始模型可以完成推理且精度正常;执行该步骤时,可以使用部分测试集,减少运行时间。
1
|
user_do_inference_torch(ori_model, test_data, test_iterations) |
1 2 3 4 5 6 |
config_file = './tmp/config.json' simple_cfg = './distill.cfg' amct.create_distill_config(config_file=config_file, model=ori_model, input_data=ori_model_input_data, config_defination=simple_cfg) |
1 2 3 4 |
compress_model = amct.create_distill_model( config_file=config_file, model=ori_model, input_data=ori_model_input_data) |
调用distill接口进行逐层蒸馏。针对配置中的蒸馏结构进行蒸馏。
1 2 3 4 5 6 7 8 9 10 |
distill_model = amct.distill( model=ori_model, compress_model config_file=config_file, train_loader, epochs=1, lr=1e-3, sample_instance=None, loss=loss, optimizer=optimizer) |
1 2 3 4 5 6 7 8 9 |
amct.save_distill_model( model, "./model/distilled" input_data, record_file="./results/records.txt" input_names=['input'], output_names=['output'], dynamic_axes={'input':{0: 'batch_size'}, 'output':{0: 'batch_size'}}) |
使用量化后仿真模型精度与2中的原始精度做对比,可以观察量化对精度的影响。
1 2 |
quant_model = './results/user_model_fake_quant_model.onnx' user_do_inference_onnx(quant_model, test_data, test_iterations) |