KnowledgeDistillConfig类方法,用户调用该方法增加自定义 loss function,而不只是使用api提供的loss function,非必须调用的方法。
本方法只能对loss function是否为MindSpore模型或PyTorch模型进行校验,不保证用户自定义loss function的可用性、正确性。
add_custom_loss_func(name, instance)
参数名 |
输入/返回值 |
含义 |
使用限制 |
---|---|---|---|
name |
输入 |
自定义loss function名称。 |
必选。 数据类型:string。 |
instance |
输入 |
自定义loss function的实例。 |
可选。 数据类型:MindSpore模型或PyTorch模型。 |
from msmodelslim.common.knowledge_distill.knowledge_distill import KnowledgeDistillConfig #用户自定义loss function的实例 class CustomLoss(Cell): def __init__(self): # init def construct(self, logits_s, logits_t): # calculate loss by logits_s and logits_t return loss custom_loss = CustomLoss() #定义配置 distill_config = KnowledgeDistillConfig() distill_config.set_hard_label (0.5, 0) \ .add_custom_loss_func("custom_loss", custom_loss) \ .add_output_soft_label({ 't_output_idx': 0, 's_output_idx': 0, "loss_func": [{"func_name": "custom_loss", "func_weight": 1}] })