auto_channel_prune_search

产品支持情况

产品

是否支持

Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件

Atlas 200I/500 A2 推理产品

Atlas 推理系列产品

Atlas 训练系列产品

功能说明

自动通道稀疏接口,根据用户模型来计算各通道的稀疏敏感度(影响精度)以及稀疏收益(影响性能),然后搜索策略依据该输入来搜索最优的逐层通道稀疏率,以平衡精度和性能。最终输出一个配置文件。

函数原型

1
auto_channel_prune_search(model, config, input_data, output_cfg, sensitivity, search_alg)

参数说明

参数名

输入/输出

说明

model

输入

含义:待稀疏的PyTorch模型。

数据类型:torch.nn.Module

config

输入

含义:自动通道稀疏配置文件路径。

基于basic_info.proto文件中的AutoChannelPruneConfig生成的简易配置文件,*.proto件所在路径为:AMCT安装目录/amct_pytorch/proto/。

*.proto文件参数解释以及生成的自动通道稀疏搜索配置文件样例请参见自动(混合精度或通道稀疏)搜索简易配置文件

数据类型:string

input_data

输入

含义:用户提供获取输入数据(含label)。

数据类型:list[data,label],列表元素数据类型为torch.tensor。

output_cfg

输入

含义:输出的最终的通道稀疏配置文件路径。

数据类型:string

sensitivity

输入

含义:敏感度计算方法。

数据类型:string或SensitivityBase的子类,string为amct已有的方法,目前可选为'TaylorLossSensitivity';SensitivityBase的子类实例化,可由用户来继承定义。

search_alg

输入

含义:待稀疏的通道搜索方法。

数据类型:string或SearchChannelBase的子类,string为amct已有的方法,目前可选为'GreedySearch';SearchChannelBase的子类实例化,可由用户来继承定义。

返回值说明

调用示例

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
import amct_pytorch as amct
#构造输入数据input_data
input_data = torch.randn(input_shape)         
model.eval()        
output = model.forward(input_data)        
labels = torch.randn(output.size())        
data = [input_data,labels]

amct.auto_channel_prune_search(
     model=model,
     config='./tmp/sample.cfg',
     input_data=data,  
     output_cfg='./tmp/output.cfg', 
     sensitivity='TaylorLossSensitivity', 
     search_alg='GreedySearch')

落盘文件说明:

保存的自动通道稀疏配置文件,需要传给通道稀疏接口完成后续的业务。