昇腾社区首页
EN
注册

Conv2dQAT

产品支持情况

产品

是否支持

Atlas A3 训练系列产品/Atlas A3 推理系列产品

Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件

Atlas 200I/500 A2 推理产品

Atlas 推理系列产品

Atlas 训练系列产品

功能说明

构造Conv2d的QAT算子。

函数原型

  • 直接构造接口:
    1
    qat = amct_pytorch.nn.module.quantization.conv2d.Conv2dQAT(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, device, dtype, config)
    
  • 基于原生算子构造接口:
    1
    qat = amct_pytorch.nn.module.quantization.conv2d.Conv2dQAT.from_float(mod, config)
    

参数说明

表1 直接构造接口参数说明

参数名

输入/输出

说明

in_channels

输入

含义:输入channel个数。

数据类型:int

out_channels

输入

含义:输出channel个数、

数据类型:int

kernel_size

输入

含义:卷积核大小。

数据类型:int/tuple

stride

输入

含义:卷积步长。

数据类型:int/tuple

默认值:1

padding

输入

含义:填充大小。

数据类型:int/tuple

默认值:0

dilation

输入

含义:kernel元素之间的间距。

数据类型:int/tuple

默认值:1

groups

输入

含义:输入和输出的连接关系。

数据类型:int

默认值:1

bias

输入

含义:是否开启偏置项参与学习。

数据类型:bool,其他数据类型(比如整数,字符串,列表等)按照Python真值判断规则转换。

默认值:True

padding_mode

输入

含义:填充方式。

使用约束:仅支持zeros

device

输入

含义:运行设备。

默认值:None

dtype

输入

含义:torch数值类型。

torch数据类型,仅支持torch.float32

config

输入

含义:量化配置,配置参考样例如下,量化配置参数的具体说明请参见量化配置参数说明

config = {
    "retrain_enable":true,
    "retrain_data_config": {
        "dst_type": "INT8",
        "batch_num": 10,
        "fixed_min": False,
        "clip_min": -1.0,
        "clip_max": 1.0
    },
    "retrain_weight_config": {
        "dst_type": "INT8",
        "weights_retrain_algo": "arq_retrain",
        "channel_wise": False
    }
}

数据类型:dict

默认值:None

表2 基于原生算子构造接口

参数名

输入/输出

说明

mod

输入

含义:待量化的原生Conv2d算子

数据类型:torch.nn.Module

config

输入

含义:量化配置,配置参考样例如下,量化配置参数的具体说明请参见量化配置参数说明

config = {
    "retrain_enable":true,
    "retrain_data_config": {
        "dst_type": "INT8",
        "batch_num": 10,
        "fixed_min": False,
        "clip_min": -1.0,
        "clip_max": 1.0
    },
    "retrain_weight_config": {
        "dst_type": "INT8",
        "weights_retrain_algo": "arq_retrain",
        "channel_wise": False
    }
}

数据类型:dict

默认值:None

返回值说明

  • 直接构造:返回构造的QAT单算子实例。
  • 基于原生算子构造:torch.nn.Module转化后的QAT单算子。

调用示例

  • 直接构造:
    1
    2
    3
    4
    5
    from amct_pytorch.nn.module.quantization.conv2d import Conv2dQAT
    
    Conv2dQAT(in_channels=1, out_channels=1, kernel_size=1, stride=1,
              padding=0, dilation=1, groups=1, bias=True,
              padding_mode='zeros', device=None, dtype=None, config=None)
    
  • 基于原生算子构造:
    1
    2
    3
    4
    5
    6
    7
    8
    import torch
    
    from amct_pytorch.nn.module.quantization.conv2d import Conv2dQAT
    
    conv2d_op = torch.nn.Conv2d(in_channels=1, out_channels=1, kernel_size=1, stride=1,
                                padding=0, dilation=1, groups=1, bias=True,
                                padding_mode='zeros', device=None, dtype=None)
    Conv2dQAT.from_float(mod=conv2d_op, config=None)