aclnnRmsNormGrad
产品支持情况
产品 | 是否支持 |
---|---|
[object Object]Atlas A3 训练系列产品/Atlas A3 推理系列产品[object Object] | √ |
[object Object]Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件[object Object] | √ |
[object Object]Atlas 200I/500 A2 推理产品[object Object] | × |
[object Object]Atlas 推理系列产品 [object Object] | √ |
[object Object]Atlas 训练系列产品[object Object] | × |
功能说明
- 算子功能:undefined的反向计算。用于计算RMSNorm的梯度,即在反向传播过程中计算输入张量的梯度。
算子公式:
- 正向公式:
- 反向推导:
函数原型
每个算子分为undefined,必须先调用aclnnRmsNormGradGetWorkspaceSize
接口获取计算所需workspace大小以及包含了算子计算流程的执行器,再调用aclnnRmsNormGrad
接口执行计算。
aclnnStatus aclnnRmsNormGradGetWorkspaceSize( const aclTensor *dy, const aclTensor *x, const aclTensor *rstd, const aclTensor *gamma, const aclTensor *dxOut, const aclTensor *dgammaOut, uint64_t *workspaceSize, aclOpExecutor **executor)
aclnnStatus aclnnRmsNormGrad( void *workspace, uint64_t workspaceSize, aclOpExecutor *executor, aclrtStream stream)
aclnnRmsNormGradGetWorkspaceSize
参数说明:
- dy(aclTensor*,计算输入):Device侧的aclTensor,表示反向传回的梯度。对应公式中的
dy
。undefined支持ND,shape支持1-8维度。- [object Object]Atlas 推理系列产品[object Object]:数据类型支持FLOAT32,FLOAT16。
- [object Object]Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件[object Object]、[object Object]Atlas A3 训练系列产品/Atlas A3 推理系列产品[object Object]:数据类型支持FLOAT32,FLOAT16,BFLOAT16。
- x(aclTensor*,计算输入):Device侧的aclTensor,正向算子的输入,表示被标准化的数据。对应公式中的
x
。undefined支持ND,shape支持1-8维度,且与入参dy
的shape一致。- [object Object]Atlas 推理系列产品[object Object]:数据类型支持FLOAT32,FLOAT16。
- [object Object]Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件[object Object]、[object Object]Atlas A3 训练系列产品/Atlas A3 推理系列产品[object Object]:数据类型支持FLOAT32,FLOAT16,BFLOAT16。
- rstd(aclTensor*,计算输入):Device侧的aclTensor,正向算子的中间计算结果。对应公式中的
Rms(x)
。数据类型支持FLOAT32。undefined支持ND,shape支持1-8维度,shape需要满足rstd_shape = x_shape[0:n],n < x_shape.dims(),n与gamma一致。 - gamma(aclTensor*,计算输入):Device侧的aclTensor,正向算子的输入。对应公式中的
g
。undefined支持ND,shape支持1-8维度,shape需要满足gamma_shape = x_shape[n:], n < x_shape.dims()。- [object Object]Atlas 推理系列产品[object Object]:数据类型支持FLOAT32,FLOAT16。
- [object Object]Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件[object Object]、[object Object]Atlas A3 训练系列产品/Atlas A3 推理系列产品[object Object]:数据类型支持FLOAT32,FLOAT16,BFLOAT16。
- dxOut(aclTensor*,计算输出):Device侧的aclTensor,表示输入
x
的梯度。对应公式中的dx
。undefined支持ND,shape支持1-8维度,shape与入参dy
的shape保持一致。- [object Object]Atlas 推理系列产品[object Object]:数据类型支持FLOAT32,FLOAT16。
- [object Object]Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件[object Object]、[object Object]Atlas A3 训练系列产品/Atlas A3 推理系列产品[object Object]:数据类型支持FLOAT32,FLOAT16,BFLOAT16。
- dgammaOut(aclTensor*,计算输出):Device侧的aclTensor,表示
gamma
的梯度。对应公式中的dg
。数据类型支持FLOAT32。undefined支持ND,shape支持1-8维度,shape与入参gamma
的shape保持一致。 - workspaceSize(uint64_t*,出参):返回用户需要在Device侧申请的workspace大小。
- executor(aclOpExecutor**,出参):返回op执行器,包含了算子计算流程。
- dy(aclTensor*,计算输入):Device侧的aclTensor,表示反向传回的梯度。对应公式中的
返回值:
aclnnStatus:返回状态码,具体参见undefined。
[object Object]
aclnnRmsNormGrad
参数说明:
- workspace(void*,入参):在Device侧申请的workspace内存地址。
- workspaceSize(uint64_t,入参):在Device侧申请的workspace大小,由第一段接口aclnnRmsNormGradGetWorkspaceSize获取。
- executor(aclOpExecutor*,入参):op执行器,包含了算子计算流程。
- stream(aclrtStream,入参):指定执行任务的Stream。
返回值: aclnnStatus:返回状态码,具体参见undefined。
约束说明
- [object Object]Atlas 推理系列产品[object Object]:
x
、dy
、gamma
输入的尾轴长度必须大于等于32 Bytes。 - 支持类型说明:
- 是否支持空Tensor:支持空进空出。
- 是否支持undefined:支持非连续Tensor。
- 各产品支持数据类型说明:
- [object Object]Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件[object Object]、[object Object]Atlas A3 训练系列产品/Atlas A3 推理系列产品[object Object]:
dy
数据类型x
数据类型rstd
数据类型gamma
数据类型dxOut
数据类型dgammaOut
数据类型FLOAT16 FLOAT16 FLOAT32 FLOAT32 FLOAT16 FLOAT32 BFLOAT16 BFLOAT16 FLOAT32 FLOAT32 BFLOAT16 FLOAT32 FLOAT16 FLOAT16 FLOAT32 FLOAT16 FLOAT16 FLOAT32 FLOAT32 FLOAT32 FLOAT32 FLOAT32 FLOAT32 FLOAT32 BFLOAT16 BFLOAT16 FLOAT32 BFLOAT16 BFLOAT16 FLOAT32 - [object Object]Atlas 推理系列产品[object Object]:
dy
数据类型x
数据类型rstd
数据类型gamma
数据类型dxOut
数据类型dgammaOut
数据类型FLOAT16 FLOAT16 FLOAT32 FLOAT16 FLOAT16 FLOAT32 FLOAT32 FLOAT32 FLOAT32 FLOAT32 FLOAT32 FLOAT32
- [object Object]Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件[object Object]、[object Object]Atlas A3 训练系列产品/Atlas A3 推理系列产品[object Object]:
调用示例
示例代码如下,仅供参考,具体编译和执行过程请参考undefined。
[object Object]