昇腾社区首页
EN
注册

torch_npu.npu_group_norm_silu

功能描述

计算输入input的组归一化结果out、均值meanOut、标准差的倒数rstdOut、以及silu的输出。

接口原型

1
torch_npu.npu_group_norm_silu(Tensor input, Tensor weight, Tensor bias, int group, float eps) -> (Tensor, Tensor, Tensor)

参数说明

  • input:Tensor类型,必选输入,源数据张量,维度需大于一维且小于等于八维,数据格式支持ND,支持非连续的Tensor。
    • Atlas 推理系列产品:数据类型支持float16、float。
    • Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件:数据类型支持float16、float、bfloat16。
  • weight:Tensor类型,可选输入,索引张量,维度为1且元素数量需与输入input的第1维度保持相同,数据格式支持ND,支持非连续的Tensor。
    • Atlas 推理系列产品:数据类型支持float16、float。
    • Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件:数据类型支持float16、float、bfloat16。
  • bias:Tensor类型,可选输入,更新数据张量,维度为1元素数量需与输入input的第1维度保持相同,数据格式支持ND,支持非连续的Tensor。
    • Atlas 推理系列产品:数据类型支持float16、float。
    • Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件:数据类型支持float16、float、bfloat16。
  • group:int类型,必选输入,表示将输入“input”的第1维度分为“group”组,“group”需大于0。
  • eps:float类型,可选参数,为保证数值稳定性而加到分母上的值,“eps”需大于0。

输出说明

  • out:Tensor类型,数据类型和shape与input相同,支持ND,支持非连续的Tensor。
    • Atlas 推理系列产品:数据类型支持float16、float。
    • Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件:数据类型支持float16、float、bfloat16。
  • meanOut:Tensor类型,数据类型与input相同,shape为(N, group)支持ND,支持非连续的Tensor。
    • Atlas 推理系列产品:数据类型支持float16、float。
    • Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件:数据类型支持float16、float、bfloat16。
  • rstdOut:Tensor类型,数据类型与input相同,shape为(N, group)。
    • Atlas 推理系列产品:数据类型支持float16、float。
    • Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件:数据类型支持float16、float、bfloat16。

约束说明

  • input、weight、bias、out、meanOut、rstdOut数据类型必须支持的范围之内。
  • out、meanOut、rstdOut的数据类型与input相同;weight、bias与input可以不同。
  • weight与bias的数据类型必须保持一致,且数据类型的精度不能低于input的数据类型。
  • weight与bias的维度需为1且元素数量需与输入input的第1维度保持相同。
  • input维度需大于一维且小于等于八维,且input第1维度能整除group。
  • input任意维都需大于0。
  • out的shape与input相同。
  • meanOut与rstdOut的shape为(N, group),其中N为input第0维度值。
  • eps需大于0。
  • group需大于0。

支持的型号

  • Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件
  • Atlas 推理系列产品

调用示例

单算子调用:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import torch
import numpy as np
import torch_npu
 
dtype = np.float32
shape_x = [24,320,48,48]
num_groups = 32
shape_c = [320]
eps = 0.00001
 
input_npu=torch.randn(shape_x,dtype=torch.float32).npu()
weight_npu=torch.randn(shape_c,dtype=torch.float32).npu()
bias_npu=torch.randn(shape_c,dtype=torch.float32).npu()
out_npu, mean_npu, rstd_out = torch_npu.npu_group_norm_silu(input_npu, weight_npu, bias_npu, group=num_groups, eps=eps)
 
 
input_npu=torch.randn(shape_x,dtype=torch.bfloat16).npu()
weight_npu=torch.randn(shape_c,dtype=torch.bfloat16).npu()
bias_npu=torch.randn(shape_c,dtype=torch.bfloat16).npu()
out_npu, mean_npu, rstd_out = torch_npu.npu_group_norm_silu(input_npu, weight_npu, bias_npu, group=num_groups, eps=eps)
 
input_npu=torch.randn(shape_x,dtype=torch.float16).npu()
weight_npu=torch.randn(shape_c,dtype=torch.float16).npu()
bias_npu=torch.randn(shape_c,dtype=torch.float16).npu()
out_npu, mean_npu, rstd_out = torch_npu.npu_group_norm_silu(input_npu, weight_npu, bias_npu, group=num_groups, eps=eps)