AntiOutlier
功能说明
构建用于异常值的类,并将模型,异常值抑制config,校准数据等传入。
函数原型
AntiOutlier(model, calib_data, cfg: Config, dag=None, logger=None, model_type=None)
参数说明
| 参数名 | 输入/返回值 | 含义 | 使用限制 | 
|---|---|---|---|
| model | 输入 | 用于大模型离群值抑制的模型。 | 必选。 数据类型:PyTorch模型。 | 
| calib_data | 输入 | 用于离群值抑制的校准数据。 | 必选。 数据类型:object。 默认值为None。 输入模板:[[input1],[input2],[input3]]。 | 
| cfg | 输入 | 已配置的AntiOutlierConfig类。 | 可选。 数据类型:Config。 | 
| dag | 输入 | 模型图。 | 可选。 默认为None,采用默认配置即可。 | 
| logger | 输入 | Logger对象。 | 可选。 数据类型:object。 默认值为None,采用默认配置即可。 | 
| model_type | 输入 | 模型类型。 | 可选。 数据类型:object。 默认值为None。 
 | 
调用示例
from modelslim.pytorch.llm_ptq.anti_outlier import AntiOutlier, AntiOutlierConfig anti_config = AntiOutlierConfig(anti_method="m2") anti_outlier = AntiOutlier(model, calib_data=dataset_calib, cfg=anti_config, model_type='Llama') anti_outlier.process() calibrator = Calibrator(model, quant_config, calib_data=dataset_calib, disable_level='L0') calibrator.run(int_infer=False) calibrator.save(qaunt_weight_save_path)
父主题: 大模型量化接口