Conv2dQAT
产品支持情况
| 产品 | 是否支持 | 
|---|---|
| √ | |
| √ | |
| √ | |
| √ | |
| √ | 
功能说明
构造Conv2d的QAT算子。
函数原型
- 直接构造接口:1qat = amct_pytorch.nn.module.quantization.conv2d.Conv2dQAT(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, device, dtype, config) 
- 基于原生算子构造接口:1qat = amct_pytorch.nn.module.quantization.conv2d.Conv2dQAT.from_float(mod, config) 
参数说明
| 参数名 | 输入/输出 | 说明 | 
|---|---|---|
| in_channels | 输入 | 含义:输入channel个数。 数据类型:int | 
| out_channels | 输入 | 含义:输出channel个数、 数据类型:int | 
| kernel_size | 输入 | 含义:卷积核大小。 数据类型:int/tuple | 
| stride | 输入 | 含义:卷积步长。 数据类型:int/tuple 默认值:1 | 
| padding | 输入 | 含义:填充大小。 数据类型:int/tuple 默认值:0 | 
| dilation | 输入 | 含义:kernel元素之间的间距。 数据类型:int/tuple 默认值:1 | 
| groups | 输入 | 含义:输入和输出的连接关系。 数据类型:int 默认值:1 | 
| bias | 输入 | 含义:是否开启偏置项参与学习。 数据类型:bool,其他数据类型(比如整数,字符串,列表等)按照Python真值判断规则转换。 默认值:True | 
| padding_mode | 输入 | 含义:填充方式。 使用约束:仅支持zeros | 
| device | 输入 | 含义:运行设备。 默认值:None | 
| dtype | 输入 | 含义:torch数值类型。 torch数据类型,仅支持torch.float32 | 
| config | 输入 | 含义:量化配置,配置参考样例如下,量化配置参数的具体说明请参见量化配置参数说明 config = {
    "retrain_enable":true,
    "retrain_data_config": {
        "dst_type": "INT8",
        "batch_num": 10,
        "fixed_min": False,
        "clip_min": -1.0,
        "clip_max": 1.0
    },
    "retrain_weight_config": {
        "dst_type": "INT8",
        "weights_retrain_algo": "arq_retrain",
        "channel_wise": False
    }
}数据类型:dict 默认值:None | 
| 参数名 | 输入/输出 | 说明 | 
|---|---|---|
| mod | 输入 | 含义:待量化的原生Conv2d算子 数据类型:torch.nn.Module | 
| config | 输入 | 含义:量化配置,配置参考样例如下,量化配置参数的具体说明请参见量化配置参数说明 config = {
    "retrain_enable":true,
    "retrain_data_config": {
        "dst_type": "INT8",
        "batch_num": 10,
        "fixed_min": False,
        "clip_min": -1.0,
        "clip_max": 1.0
    },
    "retrain_weight_config": {
        "dst_type": "INT8",
        "weights_retrain_algo": "arq_retrain",
        "channel_wise": False
    }
}数据类型:dict 默认值:None | 
返回值说明
- 直接构造:返回构造的QAT单算子实例。
- 基于原生算子构造:torch.nn.Module转化后的QAT单算子。
调用示例
- 直接构造:1 2 3 4 5 from amct_pytorch.nn.module.quantization.conv2d import Conv2dQAT Conv2dQAT(in_channels=1, out_channels=1, kernel_size=1, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', device=None, dtype=None, config=None) 
- 基于原生算子构造:1 2 3 4 5 6 7 8 import torch from amct_pytorch.nn.module.quantization.conv2d import Conv2dQAT conv2d_op = torch.nn.Conv2d(in_channels=1, out_channels=1, kernel_size=1, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', device=None, dtype=None) Conv2dQAT.from_float(mod=conv2d_op, config=None)