产品 |
是否支持 |
---|---|
√ |
|
√ |
|
√ |
|
√ |
构造Linear的QAT算子。
1 | qat = amct_pytorch.nn.module.quantization.linear.LinearQAT(in_features, out_features, bias, device, dtype, config) |
1 | qat = amct_pytorch.nn.module.quantization.linear.LinearQAT.from_float(mod, config) |
参数名 |
输入/输出 |
说明 |
---|---|---|
in_features |
输入 |
含义:输入特征数。 数据类型:int |
out_features |
输入 |
含义:输出特征数。 数据类型:int |
bias |
输入 |
含义:是否开启偏值项参与学习。 数据类型:bool,其他数据类型(比如整数,字符串,列表等)按照Python真值判断规则转换。 默认值为True |
device |
输入 |
含义:运行设备。 默认值:None |
dtype |
输入 |
含义:torch数值类型。 数据类型:torch数据类型, 仅支持torch.float32 |
config |
输入 |
含义:量化配置,配置参考样例如下,量化配置参数的具体说明请参见量化配置参数说明。 config = { "retrain_enable":true, "retrain_data_config": { "dst_type": "INT8", "batch_num": 10, "fixed_min": False, "clip_min": -1.0, "clip_max": 1.0 }, "retrain_weight_config": { "dst_type": "INT8", "weights_retrain_algo": "arq_retrain", "channel_wise": False } } 数据类型:dict 默认值:None |
参数名 |
输入/输出 |
说明 |
---|---|---|
mod |
输入 |
含义:待量化的原生Linear算子。 数据类型:torch.nn.Module |
config |
输入 |
含义:量化配置。配置参考样例如下,量化配置参数的具体说明请参见量化配置参数说明。 config = { "retrain_enable":true, "retrain_data_config": { "dst_type": "INT8", "batch_num": 10, "fixed_min": False, "clip_min": -1.0, "clip_max": 1.0 }, "retrain_weight_config": { "dst_type": "INT8", "weights_retrain_algo": "arq_retrain", "channel_wise": False } } 数据类型:dict 默认值:None |
1 2 3 4 | from amct_pytorch.nn.module.quantization.linear import LinearQAT LinearQAT(in_features=1, out_features=1, bias=True, device=None, dtype=None, config=None) |
1 2 3 4 5 6 | import torch from amct_pytorch.nn.module.quantization.linear import LinearQAT linear_op = torch.nn.Linear(in_features=1, out_features=1, bias=True, device=None, dtype=None) LinearQAT.from_float(mod=linear_op, config=None) |