量化感知训练当前仅支持INT8的基础量化:INT8量化是用8比特的INT8数据来表示32比特的FP32数据,将FP32的卷积运算过程(乘加运算)转换为INT8的卷积运算,加速运算和实现模型压缩。量化示例请参见获取更多样例>resnet_v1_50。量化感知训练支持量化的层以及约束如下:
层名 |
约束 |
---|---|
MatMul:全连接层 |
transpose_a=False, transpose_b=False,adjoint_a=False,adjoint_b=False |
Conv2D:卷积层 |
由于硬件约束,原始模型中输入通道数Cin<=16时不建议进行量化感知训练,否则可能会导致量化后的部署模型推理时精度下降 |
DepthwiseConv2dNative:Depthwise卷积层 |
|
Conv2DBackpropInput:反卷积层 |
|
AvgPool:平均下采样层 |
- |
量化感知训练接口调用流程如图1所示。
optimizer = tf.compat.v1.train.RMSPropOptimizer( ARGS.learning_rate, momentum=ARGS.momentum) train_op = optimizer.minimize(loss)
with tf.Session() as sess: sess.run(tf.compat.v1.global_variables_initializer()) sess.run(outputs) #将训练后的参数保存为checkpoint文件 saver_save.save(sess, retrain_ckpt, global_step=0)
variables_to_restore = tf.compat.v1.global_variables() saver_restore = tf.compat.v1.train.Saver(variables_to_restore) with tf.Session() as sess: sess.run(tf.compat.v1.global_variables_initializer()) #恢复训练参数 saver_restore.restore(sess, retrain_ckpt) #将量化因子写入record文件 sess.run(retrain_ops[-1]) #固化pb模型 constant_graph = tf.compat.v1.graph_util.convert_variables_to_constants( sess, eval_graph.as_graph_def(), [output.name[:-2] for output in outputs]) with tf.io.gfile.GFile(frozen_quant_eval_pb, 'wb') as f: f.write(constant_graph.SerializeToString())
import amct_tensorflow as amct amct.set_logging_level(print_level="info", save_level="info")
推荐执行该步骤,以确保原始模型可以完成推理且精度正常;执行该步骤时,可以使用部分测试集,减少运行时间。
user_test_evaluate_model(evaluate_model, test_data)
train_graph = user_load_train_graph()
config_file = './tmp/config.json' simple_cfg = './retrain.cfg' amct.create_quant_retrain_config(config_file=config_file, graph=train_graph, config_defination=simple_cfg)
record_file = './tmp/record.txt' retrain_ops = amct.create_quant_retrain_model(graph=train_graph, config_file=config_file, record_file=record_file)
optimizer = user_create_optimizer(train_graph)
推理和训练要在同一session中,推理执行的是retrain_ops[-1].output_tensor。
user_infer_graph(train_graph, retrain_ops[-1].output_tensor)
user_save_graph_to_pb(train_graph, trained_pb)
quant_model_path = './result/user_model' amct.save_quant_retrain_model(pb_model=trained_pb, outputs=user_model_outputs, record_file=record_file, save_path=quant_model_path)
quant_model = './results/user_model_quantized.pb' user_do_inference(quant_model, test_data)