The parts in blue are implemented by the user; the parts in gray are where the user calls the APIs provided by AMCT:
import amct_pytorch as amct
This step is recommended; make sure the original model can complete inference and that its accuracy is normal. When executing this step, you can use part of the test set to reduce the running time.
ori_model.load()
# Test the model
user_test_model(ori_model, test_data, test_iterations)
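`user_test_model` is user-supplied, not part of AMCT. A minimal sketch, assuming `model` is a callable returning a predicted label and `test_data` yields `(input, label)` pairs (both assumptions, since the document leaves the helper unspecified):

```python
def user_test_model(model, test_data, test_iterations):
    """Hypothetical accuracy check for the original model.

    Assumes `model` is a callable returning a predicted label and
    `test_data` is an iterable of (input, label) pairs.
    """
    correct = total = 0
    for i, (x, label) in enumerate(test_data):
        if i >= test_iterations:  # use only part of the test set to save time
            break
        correct += int(model(x) == label)
        total += 1
    return correct / total if total else 0.0
```

Capping the loop at `test_iterations` is what lets you run this step on a subset of the test set.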
Before implementing this step, first restore the trained parameters, e.g. via ori_model.load() in step 2.
simple_cfg = './compressed.cfg'
record_file = './tmp/record.txt'
compressed_retrain_model = amct.create_compressed_retrain_model(
    model=ori_model,
    input_data=ori_model_input_data,
    config_defination=simple_cfg,
    record_file=record_file)
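Between creating and saving the compressed model, the user normally fine-tunes `compressed_retrain_model` with their own training code; AMCT does not drive this loop. A minimal sketch of such a retraining step, assuming a standard PyTorch classifier and a hypothetical `train_loader` of `(input, label)` batches (both names are assumptions, not from the document):

```python
import torch

def retrain(model, train_loader, epochs, lr=1e-4):
    """Hypothetical fine-tuning loop for the compressed retrain model."""
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    loss_fn = torch.nn.CrossEntropyLoss()
    model.train()
    for _ in range(epochs):
        for x, label in train_loader:
            optimizer.zero_grad()
            loss = loss_fn(model(x), label)
            loss.backward()
            optimizer.step()
    return model
```

A small learning rate is typical here, since retraining only needs to recover accuracy lost to compression rather than train from scratch.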
save_path = './results/user_model'
amct.save_compressed_retrain_model(
    model=compressed_retrain_model,
    record_file=record_file,
    save_path=save_path,
    input_data=ori_model_input_data)
compressed_model = './results/user_model_fake_quant_model.onnx'
user_do_inference_onnx(compressed_model, test_data, test_iterations)
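`user_do_inference_onnx` is likewise user-supplied. A minimal sketch, assuming the fake-quant model is a single-input classifier evaluated with onnxruntime; the `session_factory` parameter is an assumption added here purely so the loop can be exercised without onnxruntime installed:

```python
import numpy as np

def user_do_inference_onnx(model_path, test_data, test_iterations,
                           session_factory=None):
    """Hypothetical evaluation of the fake-quant ONNX model.

    Assumes a single-input classifier; `test_data` yields
    (input_array, label) pairs. `session_factory` defaults to
    onnxruntime.InferenceSession and is injectable for testing.
    """
    if session_factory is None:
        import onnxruntime  # assumed runtime dependency
        session_factory = onnxruntime.InferenceSession
    session = session_factory(model_path)
    input_name = session.get_inputs()[0].name
    correct = total = 0
    for i, (x, label) in enumerate(test_data):
        if i >= test_iterations:
            break
        logits = session.run(None, {input_name: x})[0]
        correct += int(np.argmax(logits) == label)
        total += 1
    return correct / total if total else 0.0
```

Comparing this accuracy against the result from step 2 shows how much precision the compression cost.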
If training is interrupted and you need to restore the data from a checkpoint and continue training, the calling flow is:
import amct_pytorch as amct
ori_model = user_create_model()
simple_cfg = './compressed.cfg'
record_file = './tmp/record.txt'
compressed_pth_file = './ckpt/user_model_newest.ckpt'
compressed_retrain_model = amct.restore_compressed_retrain_model(
    model=ori_model,
    input_data=ori_model_input_data,
    config_defination=simple_cfg,
    record_file=record_file,
    pth_file=compressed_pth_file)
save_path = './results/user_model'
amct.save_compressed_retrain_model(
    model=compressed_retrain_model,
    record_file=record_file,
    save_path=save_path,
    input_data=ori_model_input_data)
compressed_model = './results/user_model_fake_quant_model.onnx'
user_do_inference_onnx(compressed_model, test_data, test_iterations)
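The restore flow above hard-codes the checkpoint path. If the checkpoint directory holds several snapshots, a small hypothetical helper (not part of AMCT) can pick the most recently written one by modification time:

```python
import glob
import os

def find_newest_checkpoint(ckpt_dir, pattern="*.ckpt"):
    """Hypothetical helper: return the most recently modified checkpoint
    in ckpt_dir matching pattern, or None if nothing matches."""
    candidates = glob.glob(os.path.join(ckpt_dir, pattern))
    if not candidates:
        return None
    return max(candidates, key=os.path.getmtime)
```

The returned path can then be passed as `pth_file` to amct.restore_compressed_retrain_model in place of the fixed './ckpt/user_model_newest.ckpt'.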