昇腾社区首页
中文
注册

class NpuCachedDropout()

API接口

class NpuCachedDropout(torch.nn.Dropout):

功能描述

在NPU设备上使用FairseqDropout。

参数说明

  • p (Float) - 元素归零的概率。
  • module_name (String) - 模型名称。

示例

调用方式示例:
from torch_npu.contrib.module import NpuDropPath
NpuCachedDropout.enable_dropout_ensemble(model)
使用示例:
  >>> model = NpuMNIST().to("npu")
  >>> x = torch.randn(2,10,16,16).to("npu")
  >>> NpuCachedDropout.enable_dropout_ensemble(model)
  >>> output = model(x)