msModelSlim工具提供了API方式的Transformer类模型权重剪枝调优,可将模型权重进行裁剪,并加载到同一模型结构下的小模型中。用户只需提供同一模型结构下小模型(通过配置较小初始化参数得到的模型实例,例如Bert模型中缩小intermediate_size和num_hidden_layers参数)和原始模型权重文件,即可调用剪枝API完成模型权重的剪枝。
目前支持MindSpore和PyTorch框架下Transformer类模型的剪枝调优,执行剪枝调优前需参考环境准备完成开发环境部署、Python环境变量、所需框架及训练服务器环境变量配置。
模型剪枝期间,用户可手动配置参数对预训练模型的权重进行裁剪,并将裁剪后的权重加载至小模型中,获取一个权重加载完毕的Transformer模型。剪枝后模型不保障精度,需要进行一定的训练来提升精度,例如通过模型蒸馏进行训练。
以下步骤以PyTorch框架的Transformer类模型为例,MindSpore框架的模型仅在调用部分接口时,入参配置有所差异,使用时请参照具体的API接口说明。
from msmodelslim.common.prune.transformer_prune.prune_model import PruneConfig from msmodelslim.common.prune.transformer_prune.prune_model import prune_model_weight
from msmodelslim import set_logger_level set_logger_level("info") #根据实际情况配置
prune_config = PruneConfig() prune_config.set_steps(['prune_blocks', 'prune_bert_intra_block']). \ add_blocks_params(pattern="bert.encoder.layer.(\d+).",layer_id_map={0: 0, 1: 2, 2: 4, 3: 6, 4: 8, 5: 10, 6: 11})
若set_steps方法中配置的剪枝步骤包含“prune_blocks”,则必须调用“add_blocks_params”方法进行配置。
import modeling # 导入bert模型 bert_config = modeling.BertConfig.from_json_file(bert_config_file) # 载入bert配置,初始化较小的模型。 bert_model = modeling.BertForQuestionAnswering(bert_config) # 实例化bert模型 prune_model_weight(bert_model, prune_config, weight_file_path = "/home/xxx/xxx.pt") #model根据实际情况配置待剪枝模型实例,weight_file_path根据实际情况配置原模型的权重文件
MindSpore模型的权重文件需为ckpt格式,PyTorch框架的权重文件需为pt/pth/pkl/bin格式,具体请参考prune_model_weight进行配置。
python3 test_prune_model.py