Conv2dQAT
功能说明
构造Conv2d的QAT算子。
函数原型
直接构造接口:
amct_pytorch.nn.module.quantization.conv2d.Conv2dQAT(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, device, dtype, config)
基于原生算子构造接口:
amct_pytorch.nn.module.quantization.conv2d.Conv2dQAT.from_float(mod, config)
参数说明
| 参数名 | 输入/输出 | 含义 | 使用限制 | 
|---|---|---|---|
| in_channels | 输入 | 输入channel个数 | 数据类型:int 必填 | 
| out_channels | 输入 | 输出channel个数 | 数据类型:int 必填 | 
| kernel_size | 输入 | 卷积核大小 | 数据类型:int/tuple 必填 | 
| stride | 输入 | 卷积步长 | 数据类型:int/tuple 默认值为1 | 
| padding | 输入 | 填充大小 | 数据类型:int/tuple 默认值为0 | 
| dilation | 输入 | kernel元素之间的间距 | 数据类型:int/tuple 默认值为1 | 
| groups | 输入 | 输入和输出的连接关系 | 数据类型:int 默认值为1 | 
| bias | 输入 | 是否开启偏值项参与学习 | 数据类型:bool 默认值为True | 
| padding_mode | 输入 | 填充方式 | 仅支持zeros | 
| device | 输入 | 运行设备 | 默认值:None | 
| dtype | 输入 | torch数值类型 | torch数据类型, 仅支持torch.float32 | 
| config | 输入 | 量化配置。 配置参考样例如下,量化配置参数的具体说明请参见量化配置参数说明。 config = {
    "retrain_data_config": {
        "dst_type": "INT8",
        "batch_num": 10,
        "fixed_min": False,
        "clip_min": -1.0,
        "clip_max": 1.0
    },
    "retrain_weight_config": {
        "dst_type": "INT8",
        "weights_retrain_algo": "arq_retrain",
        "channel_wise": False
    }
} | 数据类型:dict 默认值:None | 
| 参数名 | 输入/输出 | 含义 | 使用限制 | 
|---|---|---|---|
| mod | 输入 | 待量化的原生Conv2d算子 | 数据类型:torch.nn.Module | 
| config | 输入 | 量化配置。 配置参考样例如下,量化配置参数的具体说明请参见量化配置参数说明。 config = {
    "retrain_data_config": {
        "dst_type": "INT8",
        "batch_num": 10,
        "fixed_min": False,
        "clip_min": -1.0,
        "clip_max": 1.0
    },
    "retrain_weight_config": {
        "dst_type": "INT8",
        "weights_retrain_algo": "arq_retrain",
        "channel_wise": False
    }
} | 数据类型:dict 默认值:None | 
返回值说明
生成一个Conv2d对应QAT算子,用于后续量化感知训练。
调用示例
直接构造:
| 1 2 3 4 5 | from amct_pytorch.nn.module.quantization.conv2d import Conv2dQAT Conv2dQAT(in_channels=1, out_channels=1, kernel_size=1, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', device=None, dtype=None, config=None) | 
基于原生算子构造:
| 1 2 3 4 5 6 7 8 | import torch from amct_pytorch.nn.module.quantization.conv2d import Conv2dQAT conv2d_op = torch.nn.Conv2d(in_channels=1, out_channels=1, kernel_size=1, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', device=None, dtype=None) Conv2dQAT.from_float(mod=conv2d_op, config=None) |