Calibrator
功能说明
量化参数配置类,通过Calibrator类封装量化算法。
函数原型
Calibrator(model, cfg, calib_data=None, fuse_module_call_back=None)
参数说明
参数名 |
输入/返回值 |
含义 |
使用限制 |
|---|---|---|---|
model |
输入 |
待量化模型实例。 |
必选。 数据类型:PyTorch模型。 |
cfg |
输入 |
已配置的QuantConfig类。 |
必选。 数据类型:QuantConfig。 |
calib_data |
输入 |
模型训练数据,可输入真实数据用于Label-Free量化,也可输入虚拟数据来实现Data-Free量化。 |
可选。 数据类型:list[list[Torch.Tensor]] 或list[Torch.Tensor]。
|
fuse_module_call_back |
输入 |
BN融合用户自定义函数,在量化前会调用该回调。 |
可选。 数据类型:function。 如果模型结构特殊,不是conv->bn并列结构的,需要用户传入自定义融合函数。 |
调用示例
from modelslim.pytorch.quant.ptq_tools import QuantConfig, Calibrator
disable_names = []
input_shape = [1, 3, 224, 224]
quant_config = QuantConfig(disable_names=disable_names, amp_num=0, input_shape=input_shape)
calib_data = []
image = cv2.imdecode(np.fromfile("./random_image.jpg", dtype=np.uint8), 1)
image = cv2.resize(image, (224, 224,), interpolation=cv2.INTER_CUBIC)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image = torch.from_numpy(image).permute(2, 0, 1)/255
image = image.unsqueeze(0)
calib_data.append([image]) #传入一张随机图片数据,用于提高精度
calibrator = Calibrator(model, quant_config, calib_data=calib_data)
父主题: 训练后量化(PyTorch)