RmsNormBackwardOperation
产品支持情况
| 产品 | 是否支持 |
|---|---|
| （产品名称缺失） | √ |
| （产品名称缺失） | √ |
| （产品名称缺失） | x |
| （产品名称缺失） | x |
| （产品名称缺失） | x |
功能说明
RmsNorm的反向计算：根据正向输出的梯度dy、正向输入x、正向计算的中间结果rstd以及gamma，计算正向输入x的梯度dx和gamma的梯度dgamma。
定义
/// Parameter struct for RmsNormBackwardOperation.
/// The operation exposes no tunable options; the struct only carries
/// reserved storage so the layout can be extended later.
struct RmsNormBackwardParam {
    /// Reserved bytes; all eight default to 0 (value-initialized).
    uint8_t rsv[8]{};
};
参数列表
| 成员名称 | 类型 | 默认值 | 描述 |
|---|---|---|---|
| rsv[8] | uint8_t | {0} | 预留参数。 |
输入
| 参数 | 维度 | 数据类型 | 格式 | 描述 |
|---|---|---|---|---|
| dy | [dim_0, dim_1, ..., dim_n] | float16/float/bf16 | ND | 输入梯度，即正向输出的梯度。数据类型、维度与x一致。 |
| x | [dim_0, dim_1, ..., dim_n] | float16/float/bf16 | ND | 正向计算的输入。 |
| rstd | [dim_0, dim_1, ..., dim_n] | float | ND | 正向计算的中间结果。 |
| gamma | [dim_0, dim_1, ..., dim_n] | float16/float/bf16 | ND | 数据类型与x一致。维度数需大于0，且小于x的维度数；从最后一维向前，gamma的每一维都需要与x的对应维度保持一致。 |
输出
| 参数 | 维度 | 数据类型 | 格式 | 描述 |
|---|---|---|---|---|
| dx | [dim_0, dim_1, ..., dim_n] | float16/float/bf16 | ND | 正向输入x的梯度。数据类型、维度与x一致。 |
| dgamma | [dim_0, dim_1, ..., dim_n] | float | ND | 正向输入gamma的梯度。维度与gamma一致。 |
约束说明
无
调用示例
前置条件和编译命令请参见算子调用示例。
场景：基础场景。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 | #include <iostream> #include <vector> #include <numeric> #include "acl/acl.h" #include "atb/operation.h" #include "atb/types.h" #include "atb/train_op_params.h" #include "demo_util.h" const int32_t DEVICE_ID = 0; const uint32_t DIM_0 = 32; const uint32_t DIM_1 = 64; const uint32_t DIM_2 = 128; /** * @brief 准备atb::VariantPack中的所有输入tensor * @param contextPtr context指针 * @param stream stream * @return atb::SVector<atb::Tensor> atb::VariantPack中的输入tensor * @note 需要传入所有host侧tensor */ atb::SVector<atb::Tensor> PrepareInTensor(atb::Context *contextPtr, aclrtStream stream) { // 创建shape为[32, 64, 128]的tensor atb::Tensor dy = CreateTensorFromVector(contextPtr, stream, std::vector<float>(DIM_0 * DIM_1 * DIM_2, 2.0), ACL_FLOAT, aclFormat::ACL_FORMAT_ND, {DIM_0, DIM_1, DIM_2}); atb::Tensor x = CreateTensorFromVector(contextPtr, stream, std::vector<float>(DIM_0 * DIM_1 * DIM_2, 2.0), ACL_FLOAT, aclFormat::ACL_FORMAT_ND, {DIM_0, DIM_1, DIM_2}); atb::Tensor rstd = CreateTensorFromVector(contextPtr, stream, std::vector<float>(DIM_0 * DIM_1, 2.0), ACL_FLOAT, aclFormat::ACL_FORMAT_ND, {DIM_0, DIM_1, 1}); atb::Tensor gamma = CreateTensorFromVector( contextPtr, stream, std::vector<float>(DIM_2, 2.0), ACL_FLOAT, aclFormat::ACL_FORMAT_ND, {DIM_2}); atb::SVector<atb::Tensor> inTensors = {dy, x, rstd, gamma}; return inTensors; } /** * @brief 创建一个rmsnorm backward operation * @return atb::Operation * 返回一个Operation指针 */ atb::Operation *CreateRmsNormBackwardOperation() { atb::train::RmsNormBackwardParam param; atb::Operation *rmsNormBackwardOp = nullptr; CHECK_STATUS(atb::CreateOperation(param, &rmsNormBackwardOp)); return rmsNormBackwardOp; } int main(int argc, char 
**argv) { // 设置卡号、创建context、设置stream atb::Context *context = nullptr; void *stream = nullptr; CHECK_STATUS(aclInit(nullptr)); CHECK_STATUS(aclrtSetDevice(DEVICE_ID)); CHECK_STATUS(atb::CreateContext(&context)); CHECK_STATUS(aclrtCreateStream(&stream)); context->SetExecuteStream(stream); // 创建op atb::Operation *rmsnormBackwardOp = CreateRmsNormBackwardOperation(); // 准备输入tensor atb::VariantPack variantPack; variantPack.inTensors = PrepareInTensor(context, stream); // 放入输入tensor atb::Tensor dx = CreateTensor(ACL_FLOAT, aclFormat::ACL_FORMAT_ND, {DIM_0, DIM_1, DIM_2}); atb::Tensor dgamma = CreateTensor(ACL_FLOAT, aclFormat::ACL_FORMAT_ND, {DIM_2}); variantPack.outTensors = {dx, dgamma}; // 放入输出tensor uint64_t workspaceSize = 0; // 计算workspace大小 CHECK_STATUS(rmsnormBackwardOp->Setup(variantPack, workspaceSize, context)); uint8_t *workspacePtr = nullptr; if (workspaceSize > 0) { CHECK_STATUS(aclrtMalloc((void **)(&workspacePtr), workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST)); } // rmsnorm执行 rmsnormBackwardOp->Execute(variantPack, workspacePtr, workspaceSize, context); CHECK_STATUS(aclrtSynchronizeStream(stream)); // 流同步,等待device侧任务计算完成 // 释放资源 for (atb::Tensor &inTensor : variantPack.inTensors) { CHECK_STATUS(aclrtFree(inTensor.deviceData)); } for (atb::Tensor &outTensor : variantPack.outTensors) { CHECK_STATUS(aclrtFree(outTensor.deviceData)); } if (workspaceSize > 0) { CHECK_STATUS(aclrtFree(workspacePtr)); } CHECK_STATUS(atb::DestroyOperation(rmsnormBackwardOp)); // operation,对象概念,先释放 CHECK_STATUS(aclrtDestroyStream(stream)); CHECK_STATUS(DestroyContext(context)); // context,全局资源,后释放 CHECK_STATUS(aclFinalize()); std::cout << "Rmsnorm backward demo success!" << std::endl; return 0; } |