昇腾社区首页
中文
注册
开发者
下载

(beta)torch_npu.npu_swiglu

产品支持情况

产品 是否支持
[object Object]Atlas A2 训练系列产品[object Object]

功能说明

  • API功能:Swish门控线性单元激活函数,实现张量input的swiglu计算。

  • 计算公式:

    公式中xx是输入参数input的Tensor。dimdim是切分维度,默认为-1。AABBinput沿dim维度切分的Tensor。A表示前半部分张量,B表示后半部分张量。

    outputs=swiglu(x,dim=1)=swish(A)B=Asigmoid(A)Boutputs=swiglu(x,dim=-1)=swish(A)*B=A*sigmoid(A)*B

函数原型

[object Object]

参数说明

input (Tensor):必选参数,表示待计算的数据,对应公式中的xx。shape支持1-8维,且shape必须在入参dim对应维度上可以整除2。不支持非连续的Tensor,不支持空Tensor。数据类型支持float32float16bfloat16类型。

dim (int):可选参数,默认为-1。需要进行切分的维度序号,对input相应轴进行对半切。取值范围为[-input.dim(), input.dim()-1]。

返回值说明

Tensor

对应公式中的outputsoutputs。数据类型与计算输入input的类型一致,不支持非连续的Tensor。

调用示例

[object Object]