模型蒸馏接口,将用户提供教师模型、学生模型根据蒸馏配置进行组合,返回一个DistillDualModels实例,用户对DistillDualModels 实例进行训练。
由于PyTorch、MindSpore下蒸馏实现存在差异,对DistillDualModels实例的使用也存在如下区别。
get_distill_model(teacher, student, config)
参数名 |
输入/返回值 |
含义 |
使用限制 |
---|---|---|---|
teacher |
输入 |
教师模型。 |
必选。 数据类型:MindSpore模型或PyTorch模型。 |
student |
输入 |
学生模型。 |
必选。 数据类型:MindSpore模型或PyTorch模型。 |
config |
输入 |
蒸馏的配置。 |
必选。 数据类型:KnowledgeDistillConfig对象。 |
from msmodelslim.common.knowledge_distill.knowledge_distill import KnowledgeDistillConfig, get_distill_model #定义配置 distill_config = KnowledgeDistillConfig() distill_config. set_hard_label (0.5, 0) \ .add_inter_soft_label({ 't_module': 'uniter.encoder.encoder.blocks.11.output', 's_module': 'uniter.encoder.encoder.blocks.5.output', 't_output_idx': 0, 's_output_idx': 0, "loss_func": [{"func_name": "KDCrossEntropy", "func_weight": 1}], 'shape': [2048] }) #传入参数,返回蒸馏模型 distill_model = get_distill_model(teacher_model, student_model, distill_config)