add_custom_loss_func
功能说明
KnowledgeDistillConfig类方法,用户调用该方法增加自定义 loss function,而不只是使用api提供的loss function,非必须调用的方法。
本方法只能对loss function是否为MindSpore模型或PyTorch模型进行校验,不保证用户自定义loss function的可用性、正确性。
函数原型
add_custom_loss_func(name, instance)
参数说明
参数名 |
输入/返回值 |
含义 |
使用限制 |
|---|---|---|---|
name |
输入 |
自定义loss function名称。 |
必选。 数据类型:string。 |
instance |
输入 |
自定义loss function的实例。 |
可选。 数据类型:MindSpore模型或PyTorch模型。 |
调用示例
from msmodelslim.common.knowledge_distill.knowledge_distill import KnowledgeDistillConfig
#用户自定义loss function的实例
class CustomLoss(Cell):
def __init__(self):
# init
def construct(self, logits_s, logits_t):
# calculate loss by logits_s and logits_t
return loss
custom_loss = CustomLoss()
#定义配置
distill_config = KnowledgeDistillConfig()
distill_config.set_hard_label (0.5, 0) \
.add_custom_loss_func("custom_loss", custom_loss) \
.add_output_soft_label({
't_output_idx': 0,
's_output_idx': 0,
"loss_func": [{"func_name": "custom_loss",
"func_weight": 1}]
})