开发者
资源

algorithm_register

产品支持情况

产品

是否支持

Atlas 350 加速卡

Atlas A3 训练系列产品/Atlas A3 推理系列产品

Atlas A2 训练系列产品/Atlas A2 推理系列产品

Atlas 200I/500 A2 推理产品

x

Atlas 推理系列产品

x

Atlas 训练系列产品

x

功能说明

将用户提供的自定义算法注册到AMCT工具

函数原型

1
algorithm_register(name, src_op, quant_op, deploy_op)

参数说明

参数名

输入/输出

说明

name

输入

含义:算法名称。

数据类型:string。

src_op

输入

含义:替换的算子。

数据类型:string。

quant_op

输入

含义:量化算子。

数据类型:torch.nn.Module。

deploy_op

输入

含义:部署算子。

数据类型:torch.nn.Module。

返回值说明

约束说明

调用示例

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# 自定义算法名称
name = 'customize_algo'
# 需要量化的算子类型
src_op = 'Linear'
# 用户自己实现的量化算子
class CustomizedQuantOp(BaseQuantizeModule):
    def __init__(self,
                 ori_module,
                 layer_name,
                 quant_config):
        super().__init__(ori_module, layer_name, quant_config)

    @torch.no_grad()
    def forward(self, inputs):
        return
quant_op = CustomizedQuantOp
# 用户自己实现的部署算子
class CustomizedDeployOp(torch.nn.Module):
    def __init__(self, quant_module):
        super().__init__()

    def forward(self, x):
        return
deploy_op = CustomizedDeployOp
# 注册自定义算法
algorithm_register(name, src_op, quant_op, deploy_op)