class LabelSmoothingCrossEntropy()
API接口
class LabelSmoothingCrossEntropy(nn.Module):
功能描述
使用NPU API进行LabelSmoothing Cross Entropy。
参数说明
- smooth_factor (Float,默认值为0) -如果正在使用LabelSmoothing,请改为0.1([0, 1])。
- num_classes (Float) - 用于onehot的class数量。
输出说明
Float - shape为(k, 5)和(k, 1)的张量。标签以0为基础。
示例
调用方式示例:
from torch_npu.contrib.module import LabelSmoothingCrossEntropy m = LabelSmoothingCrossEntropy(10)
使用示例:
>>> x = torch.randn(2, 10) >>> y = torch.randint(0, 10, size=(2,)) >>> x = x.npu() >>> y = y.npu() >>> x.requires_grad = True >>> m = LabelSmoothingCrossEntropy(10) >>> npu_output = m(x, y) >>> npu_output.backward()
父主题: torch_npu.contrib