beta)torch_npu.npu_cross_entropy_loss

接口原型

torch_npu.npu_cross_entropy_loss(Tensor input, Tensor target, Tensor? weight=None, str reduction="mean", int ignore_index=-100, float label_smoothing=0.0, float lse_square_scale_for_zloss=0.0, bool return_zloss=False) -> (Tensor, Tensor, Tensor, Tensor)

功能描述

将原生CrossEntropyLoss中的log_softmax和nll_loss融合,降低计算时使用的内存。接口允许计算zloss。

参数说明

输出说明

约束说明

支持的型号

调用示例

1
2
3
4
5
6
7
8
9
import torch
import torch_npu
 
N = 4096
C = 8080
input = torch.randn(N, C).npu()
target = torch.arange(0, N).npu()
 
loss, log_prob, _, _ = torch_npu.npu_cross_entropy_loss(input, target)