create_quant_cali_model

产品支持情况

产品

是否支持

Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件

Atlas 200I/500 A2 推理产品

x

Atlas 推理系列产品

Atlas 训练系列产品

x

注:标记“x”的产品,调用接口不会报错,但是获取不到性能收益。

功能说明

KV-cache量化接口,根据模型和量化详细配置,对用户模型进行改图,将待量化Linear算子替换为输出后进行IFMR/HFMG量化的量化算子,后续用户拿到模型后进行在线校准,校准后生成量化因子保存在record_file中。

函数原型

1
calibration_model = create_quant_cali_model(config_file, record_file, model)

参数说明

参数名

输入/输出

说明

config_file

输入

含义:生成的量化配置文件路径,配置文件为json格式。

数据类型:string

使用约束:该接口输入的config.json必须和create_quant_cali_config接口输入的config.json一致

record_file

输入

含义:在线校准量化因子保存的路径及文件名称。

数据类型:string

model

输入

含义:用户提供的待量化模型。

数据类型:torch.nn.module

返回值说明

替换为校准算子的量化校准模型。

调用示例

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
import amct_pytorch as amct
# 建立待进行量化的网络图结构
model = build_model()
model.load_state_dict(torch.load(state_dict_path))

record_file = os.path.join(TMP, 'kv_cache.txt')
# 插入量化API,生成量化校准模型
calibration_model = amct.create_quant_cali_model(
                    config_file="./configs/config.json",  # 生成的量化因子记录文件
                    record_file,                   
                    model)