msModelSlim工具提供了基于重要性评估的模型剪枝调优API,用户只需要提供模型实例,即可调用剪枝API完成模型的剪枝。剪枝后的模型提升了一定的性能,减少了模型的大小,提升推理过程中的效率。
目前支持PyTorch框架下的模型剪枝调优,执行剪枝调优前需参考环境准备完成开发环境部署、Python环境变量、PyTorch框架及训练服务器环境变量配置。
from msmodelslim.pytorch.prune.prune_torch import PruneTorch
from msmodelslim import set_logger_level set_logger_level("info") #根据实际情况配置
desc = PruneTorch(model, torch.ones([1, 3, 224, 224]).type(torch.float32)).prune(0.8)
python3 train.py --model vgg16 --lr 1e-5 --epochs 10 --pretrained --batch-size 256 -j 48
将获取一个剪枝后的模型,可以进行后续的训练任务。
PruneTorch(model, torch.ones([1, 3, 224, 224]).type(torch.float32)).prune_by_desc(desc)