(beta)torch_npu.npu_group_norm_swish
接口原型
torch_npu.npu_group_norm_swish(Tensor input, int num_groups, Tensor weight, Tensor bias, float? eps=1e-5, float? swish_scale=1.0) -> (Tensor, Tensor, Tensor)
参数说明
- “input”(Tensor) - 表示需要进行组归一化的数据,支持2-8D张量,数据类型支持float16,float32,bfloat16。
- “num_groups”(Int) - 表示将input的第1维分为num_groups组,input的第1维必须能被num_groups整除。
- “weight”(Tensor) - 支持1D张量,并且第0维大小与input的第1维相同;数据类型支持float16,float32,bfloat16,并且需要与input一致。
- “bias”(Tensor) - 支持1D张量,并且第0维大小与input的第1维相同;数据类型支持float16,float32,bfloat16,并且需要与input一致。
- “eps”(Float) - 计算组归一化时加到分母上的值,以保证数值的稳定性。默认值为1e-5。
- “swish_scale”(Float) - 用于进行swish计算的值。默认值为1.0。
输出说明
out(Tensor) - 表示组归一化和swish计算的结果。
mean(Tensor) - 表示分组后的均值。
rstd(Tensor) - 表示分组后的标准差的倒数。
约束说明
需要计算反向梯度场景时,“input”的第1维除以“num_groups”的结果不能超过4000,“input”、“weight”、“bias”参数不支持含有-inf、inf或nan值。
支持的型号
Atlas A2 训练系列产品
Atlas A3 训练系列产品
调用示例
1 2 3 4 5 6 7 8 9 10 | import torch import torch_npu input = torch.randn(3, 4, 6, dtype=torch.float32).npu() weight = torch.randn(input.size(1), dtype=torch.float32).npu() bias = torch.randn(input.size(1), dtype=torch.float32).npu() num_groups = input.size(1) eps = 1e-5 swish_scale = 1.0 out, mean, rstd = torch_npu.npu_group_norm_swish(input, num_groups, weight, bias, eps=eps, swish_scale=swish_scale) |
父主题: torch_npu