quantize_model_ascend

产品支持情况

产品

是否支持

Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件

x

Atlas 200I/500 A2 推理产品

x

Atlas 推理系列产品

x

Atlas 训练系列产品

功能说明

训练后量化接口,将输入的待量化的图结构按照给定的量化配置文件进行量化处理,在传入的图结构中插入量化相关的算子,生成量化因子记录文件record_file,并返回量化处理新增的算子列表。

函数原型

1
calibration_graph, calibration_outputs=quantize_model_ascend(graph, outputs, config_file, record_file)

参数说明

参数名

输入/输出

说明

graph

输入

含义:用户传入的待量化模型的tf.Graph图。

数据类型:tf.Graph

使用约束:graph必须为推理图,图中不能包含训练模式的算子,例如FusedBatchNormV3算子的is_training必须为False。图中必须加载了训练好的权重。

outputs

输入

含义:graph中输出算子的列表。

数据类型:list,列表中元素类型为string

config_file

输入

含义:用户生成的量化配置文件,用于指定模型tf.Graph图中量化层的配置情况。

数据类型:string

record_file

输入

含义:量化因子记录文件路径及名称。

数据类型:string

返回值说明

调用示例

1
2
3
4
5
6
7
8
9
import amct_tensorflow as amct
# 建立待量化的网络结构
network = build_network()

# 插入量化API
calibration_graph, calibration_outputs = amct.quantize_model_ascend(
      graph=tf.get_default_graph(),
      config_file="./configs/config.json",
      record_file="./record_scale_offset.txt")