distill

产品支持情况

产品

是否支持

Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件

Atlas 200I/500 A2 推理产品

Atlas 推理系列产品

Atlas 训练系列产品

功能说明

蒸馏接口,将输入的待蒸馏的图结构按照给定的蒸馏量化配置文件进行蒸馏处理,返回修改后的torch.nn.module蒸馏模型。

函数原型

1
distill_model = distill(model, compress_model, config_file, train_loader, epochs=1, lr=1e-3, sample_instance=None, loss=None, optimizer=None)

参数说明

参数名

输入/输出

说明

model

输入

含义:待进行蒸馏量化的原始浮点模型,已加载权重。

数据类型:torch.nn.module

compress_model

输入

含义:修改后的可用于蒸馏的torch.nn.module模型。

数据类型:torch.nn.module

使用约束:该接口输入的模型必须是量化后的压缩模型。

config_file

输入

含义:用户生成的蒸馏量化配置文件,用于指定模型network中量化层的配置情况和蒸馏结构。

数据类型:string

使用约束:该接口输入的config.json必须和create_distill_config接口输入的config.json一致。

train_loader

输入

含义:训练数据集。

数据类型:torch.utils.data.DataLoader

使用约束:必须与模型输入大小匹配。

epochs

输入

含义:最大迭代次数。

默认值:1

数据类型:int

lr

输入

含义:学习率。

默认值:1e-3

数据类型:float

sample_instance

输入

含义:用户提供的获取模型输入数据方法的实例化对象。

默认值:None

数据类型:DistillSampleBase

使用约束:必须继承自DistillSampleBase类,并且实现get_model_input_data方法。可参考AMCT安装目录/amct_pytorch/distill/distill_sample.py文件。

loss

输入

含义:用于计算损失的实例化对象。

默认值:None

数据类型:torch.nn.modules.loss._Loss

optimizer

输入

含义:优化器的实例化对象。

默认值:None

数据类型:torch.optim.Optimizer

返回值说明

修改后的torch.nn.module蒸馏模型。

调用示例

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
import amct_pytorch as amct
# 建立待进行蒸馏量化的网络图结构
model = build_model()
model.load_state_dict(torch.load(state_dict_path))
compress_model = compress(model)
input_data = tuple([torch.randn(input_shape)])
train_loader = torch.utils.data.DataLoader(input_data)
loss = torch.nn.MSELoss()
optimizer = torch.optim.AdamW(compress_model.parameters(), lr=0.1)

# 蒸馏
distill_model = amct.distill(
                model,
                compress_model
                config_json_file,
                train_loader,
                epochs=1,
                lr=1e-3,
                sample_instance=None, 
                loss=loss,
                optimizer=optimizer)