API功能:计算输入
[object Object]和标签[object Object]之间的交叉熵损失。此API将原生[object Object]中的log_softmax和nll_loss融合,降低计算时使用的内存。计算公式:
公式中x是输入
[object Object],y 是标签[object Object],weight是权重,C 是标签数,N 是批处理大小。交叉熵损失
[object Object]的计算公式:其中的计算公式为:
第n个样本对第c个类别的对数概率
[object Object]计算公式为:
[object Object]
- input(
[object Object]): 必选参数,表示输入,对应公式中x;数据类型支持[object Object]、[object Object]、[object Object];shape为[N, C],N为批处理大小,C为标签数,必须大于0。 - target(
[object Object]): 必选参数,表示标签,对应公式中y;数据类型支持[object Object];shape为[N],与[object Object]第零维相同,取值范围[0, C)。 - weight(
[object Object]): 可选参数,表示每个类别指定的缩放权重;数据类型支持[object Object];shape为[C],与[object Object]第二维相同,取值范围(0, 1],不指定值时默认为全一。 - reduction(
[object Object]): 可选参数,表示loss的归约方式;支持范围["mean", "sum", "none"],[object Object]表示平均归约,[object Object]表示求和归约,[object Object]表示无归约,默认为[object Object]。 - ignore_index(
[object Object]): 可选参数,表示指定忽略的标签;数值必须小于C,当小于0时表示不指定忽略标签;默认值为-100。 - label_smoothing(
[object Object]): 可选参数,表示计算loss时的平滑量;取值范围[0.0, 1.0);默认值为0.0。 - lse_square_scale_for_zloss(
[object Object]): 可选参数,表示计算zloss所需要的scale;取值范围[0.0, 1.0);默认值为0.0;当前暂不支持。 - return_zloss(
[object Object]): 可选参数,控制是否返回zloss;设置为[object Object]时返回zloss,设置为[object Object]时不返回zloss;默认值为[object Object];当前暂不支持。
- loss(
[object Object]):表示输出损失;数据类型与[object Object]相同;[object Object]为[object Object]时shape为[N],与[object Object]第零维一致,否则shape为[1]。 - log_prob(
[object Object]): 表示给反向计算的输出;数据类型与[object Object]相同;shape为[N, C],与[object Object]一致。 - zloss(
[object Object]): 表示辅助损失;数据类型与[object Object]相同;shape与[object Object]一致;当[object Object]为[object Object]时输出zloss,否则将返回空tensor;当前暂不支持。 - lse_for_zloss(
[object Object]):在zloss场景下给反向计算的输出;数据类型与[object Object]相同;shape为[N],与[object Object]第零维一致;[object Object]不为[object Object]时将返回该输出,否则将返回空tensor;当前暂不支持。
- 输入shape中N取值范围(0, 200000]。
- 当input.requires_grad=True时,
[object Object]/[object Object]模式下不支持修改[object Object]的默认值;[object Object]模式下只支持传入可选参数的默认值,包括[object Object],[object Object]和[object Object]。 - 输入
[object Object]与[object Object]暂未使能。 - 输出
[object Object]与[object Object]暂未使能。 - 输出中仅
[object Object]支持梯度计算。
当reduction设置为
[object Object]时,示例如下:[object Object]当reduction设置为
[object Object]时,示例如下:[object Object]