quantize_model_ascend

Applicability

Product | Supported
Atlas A3 training series products/Atlas A3 inference series products | x
Atlas A2 training products/Atlas A2 inference products | x
Atlas 200I/500 A2 inference product | x
Atlas inference series products | x
Atlas training products |

Description

Quantizes a graph according to the quantization configuration file: inserts quantization operators into the graph, generates the quantization factor record file record_file, and returns the modified graph together with calibration_outputs.
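The record file name in this page's example (record_scale_offset.txt) suggests that the quantization factors are scales and offsets. As an illustration only, since this page does not state the exact formula or rounding AMCT uses, the sketch below shows a generic asymmetric INT8 scheme: a value x maps to q = clamp(round(x / scale) + offset, -128, 127) and is recovered approximately as (q - offset) * scale. The helper names compute_scale_offset, quantize, and dequantize are hypothetical.

```python
from typing import List, Tuple


def compute_scale_offset(min_val: float, max_val: float, num_bits: int = 8) -> Tuple[float, int, int, int]:
    """Hypothetical helper: derive a scale and integer offset mapping [min_val, max_val] onto signed ints."""
    # Widen the range to include 0.0 so that zero is exactly representable.
    min_val = min(min_val, 0.0)
    max_val = max(max_val, 0.0)
    q_min = -(2 ** (num_bits - 1))      # -128 for 8 bits
    q_max = 2 ** (num_bits - 1) - 1     # 127 for 8 bits
    scale = (max_val - min_val) / (q_max - q_min)
    if scale == 0.0:
        scale = 1.0  # degenerate all-zero range
    offset = q_min - round(min_val / scale)
    return scale, offset, q_min, q_max


def quantize(x: List[float], scale: float, offset: int, q_min: int, q_max: int) -> List[int]:
    """Map floats to clamped integers using the scale and offset."""
    return [max(q_min, min(q_max, round(v / scale) + offset)) for v in x]


def dequantize(q: List[int], scale: float, offset: int) -> List[float]:
    """Recover approximate floats from quantized integers."""
    return [(v - offset) * scale for v in q]


# Usage: quantize a few activations observed in the range [-1.0, 3.0].
scale, offset, q_min, q_max = compute_scale_offset(-1.0, 3.0)
xs = [-1.0, -0.3, 0.0, 0.7, 3.0]
qs = quantize(xs, scale, offset, q_min, q_max)
recovered = dequantize(qs, scale, offset)
```

For values inside the calibrated range, the round-trip error is at most about half of one scale step, which is why the choice of range (and therefore of scale) during calibration matters.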

Prototype

calibration_graph, calibration_outputs=quantize_model_ascend(graph, outputs, config_file, record_file)

Parameters

Parameter | Input/Output | Description
graph | Input | tf.Graph of the model to be quantized. Data type: tf.Graph. Restrictions: the graph must be an inference graph that contains no training-mode operators (for example, the is_training attribute of the FusedBatchNormV3 operator must be set to False), and it must already be loaded with trained weights.
outputs | Input | List of the names of the graph's output operators. Data type: list of strings.
config_file | Input | Path of the user-generated quantization configuration file, which specifies the configuration of each layer to be quantized in the tf.Graph. Data type: string.
record_file | Input | Path (including the file name) of the quantization factor record file. Data type: string.
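The restriction on graph (no training-mode operators, such as a FusedBatchNormV3 operator with is_training set to True) can be checked before calling quantize_model_ascend. The sketch below is a hypothetical pre-flight check written against a simplified node representation (plain dicts with name, op, and attrs keys) rather than real TensorFlow objects, so it runs without TensorFlow. With a real tf.Graph, the equivalent would be to read each batch-norm operation's attribute via op.get_attr("is_training"). The function name find_training_mode_nodes is an assumption.

```python
from typing import Dict, List

# Batch-norm op types that carry an is_training attribute in TensorFlow.
BN_OP_TYPES = {"FusedBatchNorm", "FusedBatchNormV2", "FusedBatchNormV3"}


def find_training_mode_nodes(nodes: List[Dict]) -> List[str]:
    """Hypothetical check: return names of batch-norm nodes still in training mode."""
    offending = []
    for node in nodes:
        if node["op"] not in BN_OP_TYPES:
            continue
        # TensorFlow's default for is_training on these ops is True, so a
        # missing attribute is treated as training mode here.
        if node.get("attrs", {}).get("is_training", True):
            offending.append(node["name"])
    return offending


# Usage: a toy node list with one conv node and three batch-norm nodes.
nodes = [
    {"name": "conv1/Conv2D", "op": "Conv2D", "attrs": {}},
    {"name": "bn1", "op": "FusedBatchNormV3", "attrs": {"is_training": True}},
    {"name": "bn2", "op": "FusedBatchNormV3", "attrs": {"is_training": False}},
    {"name": "bn3", "op": "FusedBatchNorm", "attrs": {}},
]
bad_nodes = find_training_mode_nodes(nodes)
```

Any names returned this way point to operators that should be rebuilt in inference mode before the graph is quantized.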

Returns

  • calibration_graph: the tf.Graph modified by the tool, with quantization operators inserted.
  • calibration_outputs: list of the quantized layer names in calibration_graph (a list of strings).

    Note: quantize_model_ascend performs BN fusion on the graph. If a BN layer is among the network's output nodes and that layer is fused, the network's output node changes. For example, Conv+BN (or Conv+BiasAdd+BN) is fused into Conv+BiasAdd, so the BiasAdd node becomes the output node equivalent to the original BN node.
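The Conv+BN to Conv+BiasAdd fusion mentioned above works because inference-mode batch normalization is a per-channel affine transform, so it can be folded into the convolution's weights and bias: with k = gamma / sqrt(var + eps), the fused weights are W * k and the fused bias is (b - mean) * k + beta. The sketch below shows these standard folding formulas on a 1x1 convolution written as one dot product per output channel, in pure Python. It illustrates the general technique and is not AMCT's internal implementation; the function names and the default eps of 1e-3 are assumptions.

```python
import math
from typing import List, Tuple


def conv1x1(x: List[float], weights: List[List[float]], bias: List[float]) -> List[float]:
    """1x1 convolution at one spatial position: one dot product plus bias per output channel."""
    return [sum(w * xi for w, xi in zip(w_row, x)) + b for w_row, b in zip(weights, bias)]


def batch_norm(y, gamma, beta, mean, var, eps=1e-3) -> List[float]:
    """Inference-mode batch normalization, applied per channel."""
    return [g * (v - m) / math.sqrt(s + eps) + bt
            for v, g, bt, m, s in zip(y, gamma, beta, mean, var)]


def fold_bn(weights, bias, gamma, beta, mean, var, eps=1e-3) -> Tuple[List[List[float]], List[float]]:
    """Fold BN parameters into the preceding convolution's weights and bias."""
    new_w, new_b = [], []
    for w_row, b, g, bt, m, s in zip(weights, bias, gamma, beta, mean, var):
        k = g / math.sqrt(s + eps)  # per-channel scale that BN applies
        new_w.append([w * k for w in w_row])
        new_b.append((b - m) * k + bt)
    return new_w, new_b


# Usage: compare Conv followed by BN against the single fused Conv+BiasAdd.
x = [0.5, -1.0, 2.0]
weights = [[0.2, -0.4, 0.1], [1.0, 0.5, -0.3]]
bias = [0.05, -0.1]
gamma, beta = [1.5, 0.8], [0.1, -0.2]
mean, var = [0.3, -0.4], [0.9, 0.25]

unfused = batch_norm(conv1x1(x, weights, bias), gamma, beta, mean, var)
fused_w, fused_b = fold_bn(weights, bias, gamma, beta, mean, var)
fused = conv1x1(x, fused_w, fused_b)
```

The two results agree up to floating-point rounding, which is why the fused BiasAdd node can stand in for the original BN output node.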

Example

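import tensorflow as tf
import amct_tensorflow as amct

# Build the network to be quantized. build_network() stands for your own model-building code.
network = build_network()

# Insert quantization operators into the graph.
calibration_graph, calibration_outputs = amct.quantize_model_ascend(
    graph=tf.get_default_graph(),
    outputs=["output_node_name"],  # hypothetical name: use the names of your graph's output operators
    config_file="./configs/config.json",
    record_file="./record_scale_offset.txt")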