Manual Quantization
This section describes the layer types supported by PTQ, the API call sequence, and examples.
The layer types that support PTQ are listed below. For details about the quantization workflow, see Sample List.
| Supported Layer Type | Restriction | Remarks |
|---|---|---|
| MatMul | transpose_a = False, transpose_b = False, adjoint_a = False, adjoint_b = False | - |
| BatchMatMul/BatchMatMulV2 | adjoint_a = False, adjoint_b = False | - |
| Conv2D | - | The weights are of type const and do not have dynamic inputs (such as placeholders). |
| DepthwiseConv2dNative | dilation = 1 | - |
| Conv2DBackpropInput | dilation = 1 | - |
| AvgPool | - | - |
API Call Sequence
Figure 1 shows the API call sequence for PTQ.
- NPU-based online inference
- Build a source TensorFlow model, then generate a quantization configuration file by calling the create_quant_config_ascend API.
- Optimize the source TensorFlow model using the quantize_model_ascend API based on the quantization configuration file. The optimized model contains quantization algorithms. Run online inference on the test and calibration datasets provided by AMCT in the NPU environment to obtain the quantization factors.
The test dataset is used to test the accuracy of the quantized model in the TensorFlow environment, while the calibration dataset is used to generate quantization factors to ensure accuracy.
- Call the save_model_ascend API to save the quantized model, which is deployable in the NPU environment.
- Inference using TensorFlow CPU:
- Scenario 1:
- Build a source TensorFlow model and then generate a quantization configuration file by using the create_quant_config call.
- Optimize the source TensorFlow model using the quantize_model API based on the quantization configuration file. The optimized model contains quantization algorithms. Run inference on the test and calibration datasets provided by AMCT in the TensorFlow (CPU version) environment to obtain the quantization factors.
The test dataset is used to test the accuracy of the quantized model in the TensorFlow environment, while the calibration dataset is used to generate quantization factors to ensure accuracy.
- Call the save_model API to save the quantized model, which can be used for accuracy simulation in the TensorFlow (CPU version) environment.
- Scenario 2:
If you have already generated quantization factors of your own and have the source TensorFlow model, you can skip the APIs in Scenario 1 and complete the quantization directly by calling the convert_model API.
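To make the role of the calibration dataset concrete, the sketch below derives a quantization scale from calibration data and then applies fake quantization. This is a minimal, purely illustrative min/max symmetric INT8 scheme; AMCT's actual calibration algorithms and function names differ, and `calibrate_scale`/`fake_quantize` are hypothetical helpers, not AMCT APIs.

```python
# Hypothetical illustration of calibration: NOT part of the AMCT API.
# A scale factor is derived from calibration data, then used to quantize
# and immediately dequantize values ("fake quantization").

def calibrate_scale(calibration_values, num_bits=8):
    """Derive a quantization scale from calibration data (symmetric min/max)."""
    max_abs = max(abs(v) for v in calibration_values)
    qmax = 2 ** (num_bits - 1) - 1          # 127 for INT8
    return max_abs / qmax if max_abs else 1.0

def fake_quantize(values, scale, num_bits=8):
    """Round to the INT8 grid, clamp, and map back to floating point."""
    qmin, qmax = -(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1
    out = []
    for v in values:
        q = max(qmin, min(qmax, round(v / scale)))
        out.append(q * scale)
    return out

data = [-1.2, -0.3, 0.0, 0.5, 1.0]          # stand-in for calibration data
scale = calibrate_scale(data)
print(fake_quantize(data, scale))
```

The difference between the original values and their fake-quantized counterparts is the quantization error that the calibration dataset helps minimize.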
Examples
The PTQ workflow goes through the following steps:
- Prepare an already-trained model and necessary datasets.
- Validate the model accuracy and environment setup in the source TensorFlow environment.
- Write a PTQ script based on AMCT API calls.
- Run the PTQ script.
- Test the accuracy of the fake-quantized model in the source TensorFlow environment.
The following details how to write a quantization script based on AMCT API calls.
- Take the following steps to get started. Update the sample code based on your situation.
- Tweak the arguments passed to AMCT API calls as required.
- Import the AMCT package and call the set_logging_level API to set the log level.
```python
import amct_tensorflow as amct
amct.set_logging_level(print_level="info", save_level="info")
```
- (Optional) Run inference on the source TensorFlow model in the NPU environment based on the test dataset to validate the inference script and environment setup. Update the sample code based on your situation.
When performing this step, pay attention to the following points:
- You are advised to perform this step to confirm that the original model can be inferred on the NPU with normal accuracy. If inference fails, quantization cannot be performed; if the accuracy in this step does not meet requirements, the subsequent quantization accuracy result will be unreliable.
- To reduce the running time, you can use only part of the test dataset for this step.
```python
user_do_inference_on_npu(ori_model, test_data)
```
- Prepare a tf.Graph based on the user_model.pb model file. (Update the sample code based on your situation.)
```python
ori_model = 'user_model.pb'
ori_graph = user_load_graph(ori_model)
```
- Run AMCT to quantize the model.
- Generate a quantization configuration file.
```python
config_file = './tmp/config.json'
skip_layers = []
amct.create_quant_config_ascend(config_file=config_file,
                                graph=ori_graph,
                                skip_layers=skip_layers)
```
- Modify the graph and insert quantization operators into the graph.
```python
record_file = './tmp/record.txt'
user_model_outputs = ['user_model_outputs0', 'user_model_outputs1']
calibration_graph, calibration_outputs = amct.quantize_model_ascend(
    graph=ori_graph,
    config_file=config_file,
    record_file=record_file,
    outputs=user_model_outputs)
```
- Run inference on the modified graph based on the calibration dataset to determine the quantization factors. Update the sample code based on your situation.
Pay attention to the following points:
- Ensure that the calibration dataset and the preprocessed data match the model to preserve the accuracy.
- The output of the quantized graph is calibration_outputs, which must be executed during online inference.
- The number of forward passes is specified by batch_num. If too few forward passes are run, the quantization factors are not written to the record file, and reading the record file for verification fails.
```python
user_do_inference_on_npu(calibration_graph, calibration_outputs, calibration_data)
```
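The batch_num requirement above can be pictured with the following hypothetical loop. `run_forward` and the batch handling stand in for your own inference code and are not AMCT APIs; the point is simply that the quantization factors are recorded only after enough forward passes have completed.

```python
# Hypothetical sketch (not AMCT code): quantization factors are recorded
# only once batch_num forward passes have run over the calibration data.

def run_calibration(batches, batch_num, run_forward):
    """Run forward passes; return True once batch_num batches were consumed."""
    done = 0
    for batch in batches:
        run_forward(batch)              # user's own inference on one batch
        done += 1
        if done >= batch_num:
            return True                 # enough data: factors can be recorded
    return False                        # too few batches: record file stays empty

ok = run_calibration(batches=[1, 2, 3], batch_num=2, run_forward=lambda b: None)
print(ok)
```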
- Save the model.
```python
quant_model_path = './results/user_model'
amct.save_model_ascend(pb_model=ori_model,
                       outputs=user_model_outputs,
                       record_file=record_file,
                       save_path=quant_model_path)
```
- (Optional) Run inference on the fake-quantized model user_model_quantized.pb in the TensorFlow environment based on the test dataset to test the accuracy. (Update the sample code based on your situation.)
Compare the accuracy of the fake-quantized model with that of the source model (see step 2).
```python
quant_model = './results/user_model_quantized.pb'
user_do_inference_on_cpu(quant_model, test_data)
```
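As a rough illustration of this comparison, a top-1 accuracy helper such as the hypothetical one below (not part of AMCT; the predictions and labels are made-up placeholders) can be applied to both models' outputs on the same test set.

```python
# Hypothetical accuracy-comparison helper: NOT part of the AMCT API.

def top1_accuracy(predictions, labels):
    """Fraction of samples whose predicted class matches the label."""
    correct = sum(1 for p, l in zip(predictions, labels) if p == l)
    return correct / len(labels)

orig_preds = [0, 1, 2, 1]      # placeholder source-model predictions
quant_preds = [0, 1, 2, 0]     # placeholder quantized-model predictions
labels = [0, 1, 2, 1]
drop = top1_accuracy(orig_preds, labels) - top1_accuracy(quant_preds, labels)
print(f"accuracy drop: {drop:.2%}")
```

A small drop is expected after quantization; a large drop suggests the calibration dataset or preprocessing does not match the model.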
