msModelSlim工具支持API方式的蒸馏调优。蒸馏调优时,用户只需要提供teacher模型、student模型和数据集,调用API接口完成模型的蒸馏调优过程。
目前支持MindSpore和PyTorch框架下Transformer类模型的蒸馏调优,执行前需参考环境准备完成开发环境部署、Python环境变量、所需框架及训练服务器环境变量配置。
模型蒸馏期间,用户可将原始Transformer模型、配置较小参数的Transformer模型分别作为teacher和student进行知识蒸馏。通过手动配置参数,返回一个待蒸馏的DistillDualModels模型实例,用户对其进行训练。训练完毕后,从DistillDualModels模型实例获取训练后的student模型,即通过蒸馏训练后的模型。
以下步骤以PyTorch框架的模型为例,MindSpore框架的模型仅在调用部分接口时,入参配置有所差异,使用时请参照具体的API接口说明。
from msmodelslim.common.knowledge_distill.knowledge_distill import KnowledgeDistillConfig, get_distill_model
from msmodelslim import set_logger_level set_logger_level("info") #根据实际情况配置
distill_config = KnowledgeDistillConfig() distill_config.add_output_soft_label({ "t_output_idx": 1, "s_output_idx": 1, "loss_func": [{"func_name": "KDCrossEntropy", "func_weight": 1, "temperature": 1}]})
distill_model = get_distill_model(teacher_model, student_model, distill_config) #请传入teacher模型、student模型的实例
将原始代码中model = modeling.BertForQuestionAnswering(config)改为model = distill_model.student_model,从而为student模型设置optimizer。
将原始代码中start_logits, end_logits = model(input_ids, segment_ids, input_mask)改为loss, student_outputs, teacher_outputs = distill_model (input_ids, segment_ids, input_mask),并注释原始的loss计算部分,从而对 DistillDualModels模型实例进行训练。
student_model = distill_model.get_student_model()
python3 distill_model.py