LinearQAT
功能说明
构造Linear的QAT算子。
函数原型
直接构造接口:
amct_pytorch.nn.module.quantization.linear.LinearQAT(in_features, out_features, bias, device, dtype, config)
基于原生算子构造接口:
amct_pytorch.nn.module.quantization.linear.LinearQAT.from_float(mod, config)
参数说明
参数名 |
输入/输出 |
含义 |
使用限制 |
|---|---|---|---|
in_features |
输入 |
输入特征数 |
数据类型:int 必填 |
out_features |
输入 |
输出特征数 |
数据类型:int 必填 |
bias |
输入 |
是否开启偏值项参与学习 |
数据类型:bool 默认值为True |
device |
输入 |
运行设备 |
默认值:None |
dtype |
输入 |
torch数值类型 |
torch数据类型, 仅支持torch.float32 |
config |
输入 |
量化配置,配置参考样例如下,量化配置参数的具体说明请参见量化配置参数说明。 config = {
"retrain_enable":true,
"retrain_data_config": {
"dst_type": "INT8",
"batch_num": 10,
"fixed_min": False,
"clip_min": -1.0,
"clip_max": 1.0
},
"retrain_weight_config": {
"dst_type": "INT8",
"weights_retrain_algo": "arq_retrain",
"channel_wise": False
}
}
|
数据类型:dict 默认值:None |
参数名 |
输入/输出 |
含义 |
使用限制 |
|---|---|---|---|
mod |
输入 |
待量化的原生Linear算子 |
数据类型:torch.nn.Module |
config |
输入 |
量化配置。 配置参考样例如下,量化配置参数的具体说明请参见量化配置参数说明。 config = {
"retrain_enable":true,
"retrain_data_config": {
"dst_type": "INT8",
"batch_num": 10,
"fixed_min": False,
"clip_min": -1.0,
"clip_max": 1.0
},
"retrain_weight_config": {
"dst_type": "INT8",
"weights_retrain_algo": "arq_retrain",
"channel_wise": False
}
}
|
数据类型:dict 默认值:None |
返回值说明
生成一个Linear对应QAT算子,用于后续量化感知训练。
调用示例
直接构造:
1 2 3 4 | from amct_pytorch.nn.module.quantization.linear import LinearQAT LinearQAT(in_features=1, out_features=1, bias=True, device=None, dtype=None, config=None) |
基于原生算子构造:
1 2 3 4 5 6 | import torch from amct_pytorch.nn.module.quantization.linear import LinearQAT linear_op = torch.nn.Linear(in_features=1, out_features=1, bias=True, device=None, dtype=None) LinearQAT.from_float(mod=linear_op, config=None) |
父主题: 单算子模式