
aclnnMlaPrologV2WeightNz

Supported Product Models

  • Atlas A2 training series products / Atlas 800I A2 inference products / A200I A2 Box heterogeneous components
  • Atlas A3 training series products / Atlas A3 inference series products

Function Prototype

Each operator uses a two-stage interface: you must first call "aclnnMlaPrologV2WeightNzGetWorkspaceSize" to validate the arguments and compute the required workspace size based on the calling flow, and then call "aclnnMlaPrologV2WeightNz" to execute the computation. A minimal calling sketch follows the prototypes below.

  • aclnnStatus aclnnMlaPrologV2WeightNzGetWorkspaceSize(const aclTensor *tokenX, const aclTensor *weightDq, const aclTensor *weightUqQr, const aclTensor *weightUk, const aclTensor *weightDkvKr, const aclTensor *rmsnormGammaCq, const aclTensor *rmsnormGammaCkv, const aclTensor *ropeSin, const aclTensor *ropeCos, const aclTensor *cacheIndex, aclTensor *kvCacheRef, aclTensor *krCacheRef, const aclTensor *dequantScaleXOptional, const aclTensor *dequantScaleWDqOptional, const aclTensor *dequantScaleWUqQrOptional, const aclTensor *dequantScaleWDkvKrOptional, const aclTensor *quantScaleCkvOptional, const aclTensor *quantScaleCkrOptional, const aclTensor *smoothScalesCqOptional, double rmsnormEpsilonCq, double rmsnormEpsilonCkv, char *cacheModeOptional, const aclTensor *queryOut, const aclTensor *queryRopeOut, const aclTensor *dequantScaleQNopeOutOptional, uint64_t *workspaceSize, aclOpExecutor **executor)
  • aclnnStatus aclnnMlaPrologV2WeightNz(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor, aclrtStream stream)
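
The following is a minimal sketch of the two-stage calling sequence, assuming all input and output tensors, the epsilon values, the cache mode string, and the stream have already been created (error handling elided; see the complete sample in the Invocation Example section):

  // Stage 1: validate arguments and query the required workspace size.
  uint64_t workspaceSize = 0;
  aclOpExecutor* executor = nullptr;
  aclnnStatus ret = aclnnMlaPrologV2WeightNzGetWorkspaceSize(
      tokenX, weightDq, weightUqQr, weightUk, weightDkvKr,
      rmsnormGammaCq, rmsnormGammaCkv, ropeSin, ropeCos, cacheIndex,
      kvCacheRef, krCacheRef, dequantScaleXOptional, dequantScaleWDqOptional,
      dequantScaleWUqQrOptional, dequantScaleWDkvKrOptional,
      quantScaleCkvOptional, quantScaleCkrOptional, smoothScalesCqOptional,
      rmsnormEpsilonCq, rmsnormEpsilonCkv, cacheModeOptional,
      queryOut, queryRopeOut, dequantScaleQNopeOutOptional,
      &workspaceSize, &executor);

  // Allocate the workspace on the device only if one is required.
  void* workspace = nullptr;
  if (workspaceSize > 0) {
      aclrtMalloc(&workspace, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
  }

  // Stage 2: launch the computation on the given stream.
  ret = aclnnMlaPrologV2WeightNz(workspace, workspaceSize, executor, stream);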

Function Description

  • Operator function: computes the preprocessing (prolog) of Multi-Head Latent Attention in inference scenarios. The main computation consists of five paths:

    • First, the input $x$ is multiplied by $W^{DQ}$ (down-projection) and passed through RmsNorm, then split into two paths: the first path is multiplied by $W^{UQ}$ and $W^{UK}$ (two up-projections) to produce $q^N$; the second path is multiplied by $W^{QR}$ and passed through rotary position embedding (RoPE) to produce $q^R$.
    • The third path multiplies the input $x$ by $W^{DKV}$ (down-projection), applies RmsNorm, and writes the result into a cache to produce $k^C$.
    • The fourth path multiplies the input $x$ by $W^{KR}$, applies rotary position embedding, and writes the result into another cache to produce $k^R$.
    • The fifth path produces the quantization parameters obtained by applying DynamicQuant to the output $q^N$.
    • The weight parameters weightDq, weightUqQr, and weightDkvKr must be passed in NZ format.
  • Computation formulas

    RmsNorm formula:

    $$\text{RmsNorm}(x) = \gamma \cdot \frac{x_i}{\text{RMS}(x)}, \qquad \text{RMS}(x) = \sqrt{\frac{1}{N} \sum_{i=1}^{N} x_i^2 + \epsilon}$$

    Query computation, consisting of a down-projection, RmsNorm, and two up-projections:

    $$c^Q = \text{RmsNorm}(x \cdot W^{DQ}), \qquad q^C = c^Q \cdot W^{UQ}, \qquad q^N = q^C \cdot W^{UK}$$

    RoPE rotary position embedding applied to Query:

    $$q^R = \text{ROPE}(c^Q \cdot W^{QR})$$

    Key computation, consisting of a down-projection and RmsNorm; the result is stored into the cache:

    $$c^{KV} = \text{RmsNorm}(x \cdot W^{DKV}), \qquad k^C = \text{Cache}(c^{KV})$$

    RoPE rotary position embedding applied to Key; the result is stored into the cache:

    $$k^R = \text{Cache}(\text{ROPE}(x \cdot W^{KR}))$$

    Dequant Scale Query Nope formula:

    $$\text{dequantScaleQNope} = \text{RowMax}(\text{abs}(q^N)) / 127, \qquad q^N = \text{round}(q^N / \text{dequantScaleQNope})$$
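
As a concrete illustration of the RmsNorm and DynamicQuant formulas above, here is a host-side C++ reference sketch that applies both to a single row of data. It is a minimal scalar sketch for understanding the math only: the helper names rmsNormRow and dynamicQuantRow are illustrative and not part of the aclnn API, and the operator itself performs these computations on-device in BFLOAT16/INT8.

  #include <algorithm>
  #include <cmath>
  #include <cstdint>
  #include <vector>

  // Reference RmsNorm on one row: y_i = gamma_i * x_i / RMS(x), where
  // RMS(x) = sqrt(mean(x_i^2) + eps). Assumes gamma.size() == x.size().
  std::vector<float> rmsNormRow(const std::vector<float>& x,
                                const std::vector<float>& gamma, float eps) {
      float sumSq = 0.0f;
      for (float v : x) sumSq += v * v;
      float rms = std::sqrt(sumSq / x.size() + eps);
      std::vector<float> y(x.size());
      for (size_t i = 0; i < x.size(); ++i) y[i] = gamma[i] * x[i] / rms;
      return y;
  }

  // Reference per-row dynamic quantization, matching the dequantScaleQNope
  // formula: scale = RowMax(abs(q)) / 127, q_int8 = round(q / scale).
  float dynamicQuantRow(const std::vector<float>& q, std::vector<int8_t>& out) {
      float rowMax = 0.0f;
      for (float v : q) rowMax = std::max(rowMax, std::fabs(v));
      float scale = rowMax / 127.0f;
      if (scale == 0.0f) scale = 1.0f;  // guard an all-zero row
      out.resize(q.size());
      for (size_t i = 0; i < q.size(); ++i) {
          out[i] = static_cast<int8_t>(std::lround(q[i] / scale));
      }
      return scale;  // the dequantization scale for this row
  }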

aclnnMlaPrologV2WeightNzGetWorkspaceSize

  • Parameters:

    • tokenX (aclTensor*, input): the input tensor $x$ used to compute Query and Key; an aclTensor on the Device side. Its shape supports 2 or 3 dimensions, laid out as (T,He) or (B,S,He); its dtype supports BFLOAT16 and INT8; its data format supports ND.

    • weightDq (aclTensor*, input): the down-projection weight matrix used to compute Query, corresponding to $W^{DQ}$ in the formulas; an aclTensor on the Device side. Its shape supports 2 dimensions, laid out as (He,Hcq); its dtype supports BFLOAT16 and INT8; its data format supports FRACTAL_NZ.

    • weightUqQr (aclTensor*, input): the up-projection weight matrix for Query and the position-embedding weight matrix for Query, corresponding to $W^{UQ}$ and $W^{QR}$ in the formulas; an aclTensor on the Device side. Its shape supports 2 dimensions, laid out as (Hcq,N*(D+Dr)); its dtype supports BFLOAT16 and INT8; its data format supports FRACTAL_NZ.

    • weightUk (aclTensor*, input): the up-projection weight used to compute Key, corresponding to $W^{UK}$ in the formulas; an aclTensor on the Device side. Its shape supports 3 dimensions, laid out as (N,D,Hckv); its dtype supports BFLOAT16; its data format supports ND.

    • weightDkvKr (aclTensor*, input): the down-projection weight matrix for Key and the position-embedding weight matrix for Key, corresponding to $W^{DKV}$ and $W^{KR}$ in the formulas; an aclTensor on the Device side. Its shape supports 2 dimensions, laid out as (He,Hckv+Dr); its dtype supports BFLOAT16 and INT8; its data format supports FRACTAL_NZ.

    • rmsnormGammaCq (aclTensor*, input): the $\gamma$ parameter in the RmsNorm formula for computing $c^Q$; an aclTensor on the Device side. Its shape supports 1 dimension, laid out as (Hcq); its dtype supports BFLOAT16; its data format supports ND.

    • rmsnormGammaCkv (aclTensor*, input): the $\gamma$ parameter in the RmsNorm formula for computing $c^{KV}$; an aclTensor on the Device side. Its shape supports 1 dimension, laid out as (Hckv); its dtype supports BFLOAT16; its data format supports ND.

    • ropeSin (aclTensor*, input): the sine parameter matrix used to compute the rotary position embedding; an aclTensor on the Device side. Its shape supports 2 or 3 dimensions, laid out as (T,Dr) or (B,S,Dr); its dtype supports BFLOAT16; its data format supports ND.

    • ropeCos (aclTensor*, input): the cosine parameter matrix used to compute the rotary position embedding; an aclTensor on the Device side. Its shape supports 2 or 3 dimensions, laid out as (T,Dr) or (B,S,Dr); its dtype supports BFLOAT16; its data format supports ND.

    • cacheIndex (aclTensor*, input): the indices used to store into kvCache and krCache; an aclTensor on the Device side. Its shape supports 1 or 2 dimensions, laid out as (T) or (B,S); its dtype supports INT64; its data format supports ND.

    • kvCacheRef (aclTensor*, input): the kv cache tensor addressed by cacheIndex. Its shape supports 4 dimensions, laid out as (BlockNum,BlockSize,Nkv,Hckv); its dtype supports BFLOAT16 and INT8; its data format supports ND. The result is updated in place; the updated result corresponds to $k^C$ in the formulas.

    • krCacheRef (aclTensor*, input): the cache used for the Key position embedding; an aclTensor on the Device side. Its shape supports 4 dimensions, laid out as (BlockNum,BlockSize,Nkv,Dr); its dtype supports BFLOAT16 and INT8; its data format supports ND. The result is updated in place; the updated result corresponds to $k^R$ in the formulas.

    • dequantScaleXOptional (aclTensor*, input): the dequantization parameter applied after down-projection when the input tokenX is of type INT8; tokenX is quantized per-token. Its shape supports 2 dimensions, laid out as (T, 1) or (B*S, 1); its dtype supports FLOAT; its data format supports ND.

    • dequantScaleWDqOptional (aclTensor*, input): the dequantization parameter applied after down-projection when the input tokenX is of type INT8; the weight is quantized per-channel. Its shape supports 2 dimensions, laid out as (1, Hcq); its dtype supports FLOAT; its data format supports ND.

    • dequantScaleWUqQrOptional (aclTensor*, input): the dequantization parameter applied after the MatmulQcQr matrix multiply; the quantization is per-channel; an aclTensor on the Device side. Its shape supports 2 dimensions, laid out as (1,N*(D+Dr)); its dtype supports FLOAT; its data format supports ND.

    • dequantScaleWDkvKrOptional (aclTensor*, input): the dequantization parameter applied after the weightDkvKr (Key-side down-projection) matrix multiply; the quantization is per-channel; an aclTensor on the Device side. Its shape supports 2 dimensions, laid out as (1,Hckv+Dr); its dtype supports FLOAT; its data format supports ND.

    • quantScaleCkvOptional (aclTensor*, input): the quantization parameter applied to the data written into the KVCache; an aclTensor on the Device side. Its shape supports 2 dimensions, laid out as (1,Hckv); its dtype supports FLOAT; its data format supports ND.

    • quantScaleCkrOptional (aclTensor*, input): the quantization parameter applied to the data written into the KRCache; an aclTensor on the Device side. Its shape supports 2 dimensions, laid out as (1,Dr); its dtype supports FLOAT; its data format supports ND.

    • smoothScalesCqOptional (aclTensor*, input): the parameter used for dynamic quantization of the RmsNormCq output; an aclTensor on the Device side. Its shape supports 2 dimensions, laid out as (1,Hcq); its dtype supports FLOAT; its data format supports ND.

    • rmsnormEpsilonCq (double, input): the $\epsilon$ parameter in the RmsNorm formula for computing $c^Q$; pass 1e-05 unless you have a specific requirement.

    • rmsnormEpsilonCkv (double, input): the $\epsilon$ parameter in the RmsNorm formula for computing $c^{KV}$; pass 1e-05 unless you have a specific requirement.

    • cacheModeOptional (char*, input): the kvCache mode; supports "PA_BSND" and "PA_NZ"; pass "PA_BSND" unless you have a specific requirement.

    • queryOut (aclTensor*, output): the Query output tensor, corresponding to $q^N$ in the formulas; an aclTensor on the Device side. Its shape supports 3 or 4 dimensions, laid out as (T,N,Hckv) or (B,S,N,Hckv); its dtype supports BFLOAT16 (INT8 in the fully quantized, kv_cache-quantized scenario; see the table below); its data format supports ND.

    • queryRopeOut (aclTensor*, output): the Query position-embedding output tensor, corresponding to $q^R$ in the formulas; an aclTensor on the Device side. Its shape supports 3 or 4 dimensions, laid out as (T,N,Dr) or (B,S,N,Dr); its dtype supports BFLOAT16; its data format supports ND.

    • dequantScaleQNopeOutOptional (aclTensor*, output): the quantization parameters of the Query output tensor. Its shape supports 3 dimensions; in the fully quantized, kv_cache-quantized scenario its shape is (B*S, N, 1) or (T, N, 1), its dtype supports FLOAT, and its data format supports ND. In other scenarios this output is not modified, so no constraint is imposed on it, a null pointer is allowed, and it is returned as passed in.

    • workspaceSize (uint64_t*, output): returns the workspace size that needs to be allocated on the Device side.

    • executor (aclOpExecutor**, output): returns the op executor, which contains the operator's computation plan.

  • Return value:

    aclnnStatus: returns a status code. For details, see aclnn API return codes.

    The first-stage interface performs argument validation. The error codes below indicate:
    - 161001 (ACLNN_ERR_PARAM_NULLPTR): a required parameter is a null pointer.
    - 161002 (ACLNN_ERR_PARAM_INVALID): the shape or dtype of an input parameter is outside the supported range.
    - 361001 (ACLNN_ERR_RUNTIME_ERROR): an NPU runtime interface called internally by the API returned an error.
    - 561002 (ACLNN_ERR_INNER_TILING_ERROR): tiling failed; the dtype or shape of an input parameter is incorrect.
  • Constraints

    • Meaning of the shape fields:
        B: Batch, the batch size of the input samples; range: 1 to 65536
        S: Seq-Length, the sequence length of the input samples; range: 1 to 16
        He: Head-Size, the size of the hidden layer; fixed at 7168
        Hcq: the dimension of the low-rank q matrix; fixed at 1536
        N: Head-Num, the number of attention heads; one of 8, 16, 32, 64, 128
        Hckv: the dimension of the low-rank kv matrix; fixed at 512
        D: the qk dimension without position embedding; fixed at 128
        Dr: the qk position-embedding dimension; fixed at 64
        Nkv: the number of kv heads; fixed at 1
        BlockNum: the number of blocks in the PagedAttention scenario; computed as B*Skv/BlockSize rounded up, where Skv is the kv sequence length (see the sketch after the table below)
        BlockSize: the block size in the PagedAttention scenario; one of 16, 128
        T: the size after merging the B and S axes; range: 1 to 1048576
    • Shape constraints
      If tokenX merges the B and S axes, i.e. its shape is (T, He):
        ropeSin and ropeCos have shape (T, Dr)
        cacheIndex has shape (T,)
        dequantScaleXOptional has shape (T, 1)
        queryOut has shape (T, N, Hckv)
        queryRopeOut has shape (T, N, Dr)
        In the fully quantized scenario, dequantScaleQNopeOutOptional has shape (T, N, 1); in other scenarios, (1)
    
      If tokenX does not merge the B and S axes, i.e. its shape is (B, S, He):
        ropeSin and ropeCos have shape (B, S, Dr)
        cacheIndex has shape (B, S)
        dequantScaleXOptional has shape (B*S, 1)
        queryOut has shape (B, S, N, Hckv)
        queryRopeOut has shape (B, S, N, Dr)
        In the fully quantized scenario, dequantScaleQNopeOutOptional has shape (B*S, N, 1); in other scenarios, (1)
    • The aclnnMlaPrologV2WeightNz interface currently supports the following scenarios:

      | Scenario | kv_cache mode | Inputs | Outputs |
      |---|---|---|---|
      | Non-quantized | - | All inputs are non-quantized data | All outputs are non-quantized data |
      | Partially quantized | kv_cache non-quantized | weightUqQr carries per-token quantized data; all other inputs are non-quantized | All outputs return non-quantized data |
      | Partially quantized | kv_cache quantized | weightUqQr carries per-token quantized data; kvCacheRef and krCacheRef carry per-channel quantized data; all other inputs are non-quantized | kvCacheRef and krCacheRef return per-channel quantized data; all other outputs return non-quantized data |
      | Fully quantized | kv_cache non-quantized | tokenX carries per-token quantized data; weightDq, weightUqQr, and weightDkvKr carry per-channel quantized data; all other inputs are non-quantized | All outputs are non-quantized data |
      | Fully quantized | kv_cache quantized | tokenX carries per-token quantized data; weightDq, weightUqQr, and weightDkvKr carry per-channel quantized data; kvCacheRef carries per-tensor quantized data; all other inputs are non-quantized | queryOut returns per-token-per-head quantized data; kvCacheRef returns per-tensor quantized data; all other outputs return non-quantized data |

      In the different quantization scenarios, the dtype and shape of each parameter must satisfy the following combinations (each cell lists dtype and shape):

      | Parameter | Non-quantized | Partially quantized, kv_cache non-quantized | Partially quantized, kv_cache quantized | Fully quantized, kv_cache non-quantized | Fully quantized, kv_cache quantized |
      |---|---|---|---|---|---|
      | tokenX | BFLOAT16, (B,S,He) or (T,He) | BFLOAT16, (B,S,He) or (T,He) | BFLOAT16, (B,S,He) or (T,He) | INT8, (B,S,He) or (T,He) | INT8, (B,S,He) or (T,He) |
      | weightDq | BFLOAT16, (He,Hcq) | BFLOAT16, (He,Hcq) | BFLOAT16, (He,Hcq) | INT8, (He,Hcq) | INT8, (He,Hcq) |
      | weightUqQr | BFLOAT16, (Hcq,N*(D+Dr)) | INT8, (Hcq,N*(D+Dr)) | INT8, (Hcq,N*(D+Dr)) | INT8, (Hcq,N*(D+Dr)) | INT8, (Hcq,N*(D+Dr)) |
      | weightUk | BFLOAT16, (N,D,Hckv) | BFLOAT16, (N,D,Hckv) | BFLOAT16, (N,D,Hckv) | BFLOAT16, (N,D,Hckv) | BFLOAT16, (N,D,Hckv) |
      | weightDkvKr | BFLOAT16, (He,Hckv+Dr) | BFLOAT16, (He,Hckv+Dr) | BFLOAT16, (He,Hckv+Dr) | INT8, (He,Hckv+Dr) | INT8, (He,Hckv+Dr) |
      | rmsnormGammaCq | BFLOAT16, (Hcq) | BFLOAT16, (Hcq) | BFLOAT16, (Hcq) | BFLOAT16, (Hcq) | BFLOAT16, (Hcq) |
      | rmsnormGammaCkv | BFLOAT16, (Hckv) | BFLOAT16, (Hckv) | BFLOAT16, (Hckv) | BFLOAT16, (Hckv) | BFLOAT16, (Hckv) |
      | ropeSin | BFLOAT16, (B,S,Dr) or (T,Dr) | BFLOAT16, (B,S,Dr) or (T,Dr) | BFLOAT16, (B,S,Dr) or (T,Dr) | BFLOAT16, (B,S,Dr) or (T,Dr) | BFLOAT16, (B,S,Dr) or (T,Dr) |
      | ropeCos | BFLOAT16, (B,S,Dr) or (T,Dr) | BFLOAT16, (B,S,Dr) or (T,Dr) | BFLOAT16, (B,S,Dr) or (T,Dr) | BFLOAT16, (B,S,Dr) or (T,Dr) | BFLOAT16, (B,S,Dr) or (T,Dr) |
      | cacheIndex | INT64, (B,S) or (T) | INT64, (B,S) or (T) | INT64, (B,S) or (T) | INT64, (B,S) or (T) | INT64, (B,S) or (T) |
      | kvCacheRef | BFLOAT16, (BlockNum,BlockSize,Nkv,Hckv) | BFLOAT16, (BlockNum,BlockSize,Nkv,Hckv) | INT8, (BlockNum,BlockSize,Nkv,Hckv) | BFLOAT16, (BlockNum,BlockSize,Nkv,Hckv) | INT8, (BlockNum,BlockSize,Nkv,Hckv) |
      | krCacheRef | BFLOAT16, (BlockNum,BlockSize,Nkv,Dr) | BFLOAT16, (BlockNum,BlockSize,Nkv,Dr) | INT8, (BlockNum,BlockSize,Nkv,Dr) | BFLOAT16, (BlockNum,BlockSize,Nkv,Dr) | BFLOAT16, (BlockNum,BlockSize,Nkv,Dr) |
      | dequantScaleXOptional | Not required | Not required | Not required | FLOAT, (B*S,1) or (T,1) | FLOAT, (B*S,1) or (T,1) |
      | dequantScaleWDqOptional | Not required | Not required | Not required | FLOAT, (1,Hcq) | FLOAT, (1,Hcq) |
      | dequantScaleWUqQrOptional | Not required | FLOAT, (1,N*(D+Dr)) | FLOAT, (1,N*(D+Dr)) | FLOAT, (1,N*(D+Dr)) | FLOAT, (1,N*(D+Dr)) |
      | dequantScaleWDkvKrOptional | Not required | Not required | Not required | FLOAT, (1,Hckv+Dr) | FLOAT, (1,Hckv+Dr) |
      | quantScaleCkvOptional | Not required | Not required | FLOAT, (1,Hckv) | Not required | FLOAT, (1,Hckv) |
      | quantScaleCkrOptional | Not required | Not required | FLOAT, (1,Dr) | Not required | Not required |
      | smoothScalesCqOptional | Not required | FLOAT, (1,Hcq) | FLOAT, (1,Hcq) | FLOAT, (1,Hcq) | FLOAT, (1,Hcq) |
      | queryOut | BFLOAT16, (B,S,N,Hckv) or (T,N,Hckv) | BFLOAT16, (B,S,N,Hckv) or (T,N,Hckv) | BFLOAT16, (B,S,N,Hckv) or (T,N,Hckv) | BFLOAT16, (B,S,N,Hckv) or (T,N,Hckv) | INT8, (B,S,N,Hckv) or (T,N,Hckv) |
      | queryRopeOut | BFLOAT16, (B,S,N,Dr) or (T,N,Dr) | BFLOAT16, (B,S,N,Dr) or (T,N,Dr) | BFLOAT16, (B,S,N,Dr) or (T,N,Dr) | BFLOAT16, (B,S,N,Dr) or (T,N,Dr) | BFLOAT16, (B,S,N,Dr) or (T,N,Dr) |
      | dequantScaleQNopeOutOptional | Not required | Not required | Not required | Not required | FLOAT, (B*S,N,1) or (T,N,1) |
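
As referenced in the shape-field list above, BlockNum is B*Skv/BlockSize rounded up. The following is a minimal sketch of that computation; the helper name and the example value of Skv are illustrative only:

  // BlockNum for the PagedAttention caches: ceil(B * Skv / BlockSize).
  // Illustrative helper, not part of the aclnn API.
  int64_t ComputeBlockNum(int64_t b, int64_t skv, int64_t blockSize) {
      return (b * skv + blockSize - 1) / blockSize;  // ceiling division
  }
  // Example: B = 8, Skv = 256, BlockSize = 128 -> BlockNum = 16, which matches
  // the kvCache shape (16, 128, 1, 512) used in the sample code below.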

aclnnMlaPrologV2WeightNz

  • Parameters:

    • workspace (void*, input): the address of the workspace memory allocated on the Device side.
    • workspaceSize (uint64_t, input): the size of the workspace allocated on the Device side, obtained from the first-stage interface aclnnMlaPrologV2WeightNzGetWorkspaceSize.
    • executor (aclOpExecutor*, input): the op executor, which contains the operator's computation plan.
    • stream (aclrtStream, input): the AscendCL stream on which the task executes.
  • Return value:

    aclnnStatus: returns a status code. For details, see aclnn API return codes.

Constraints

  • When using this interface together with PyTorch, ensure that the versions of the CANN packages match the versions of the PyTorch packages.
  • Passing empty tensors as inputs is not supported.

Invocation Example

The sample code below is for reference only. For the detailed build and run procedure, see Compiling and Running a Sample.

  #include <iostream>
  #include <vector>
  #include <cstdint>
  #include "acl/acl.h"
  #include "aclnnop/aclnn_mla_prolog_v2_weight_nz.h"

  #define CHECK_RET(cond, return_expr) \
    do {                               \
      if (!(cond)) {                   \
        return_expr;                   \
      }                                \
    } while (0)

  #define LOG_PRINT(message, ...)     \
  do {                                \
    printf(message, ##__VA_ARGS__);   \
  } while (0)

  int64_t GetShapeSize(const std::vector<int64_t>& shape) {
      int64_t shape_size = 1;
      for (auto i : shape) {
          shape_size *= i;
      }
      return shape_size;
  }

  int Init(int32_t deviceId, aclrtStream* stream) {
      // Boilerplate: AscendCL initialization
      auto ret = aclInit(nullptr);
      CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclInit failed. ERROR: %d\n", ret); return ret);
      ret = aclrtSetDevice(deviceId);
      CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSetDevice failed. ERROR: %d\n", ret); return ret);
      ret = aclrtCreateStream(stream);
      CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtCreateStream failed. ERROR: %d\n", ret); return ret);
      return 0;
  }

  template <typename T>
  int CreateAclTensorND(const std::vector<T>& shape, void** deviceAddr, void** hostAddr,
                        aclDataType dataType, aclTensor** tensor) {
      auto size = GetShapeSize(shape) * sizeof(T);
      // Allocate device-side memory with aclrtMalloc
      auto ret = aclrtMalloc(deviceAddr, size, ACL_MEM_MALLOC_HUGE_FIRST);
      CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMalloc failed. ERROR: %d\n", ret); return ret);
      // Allocate host-side memory with aclrtMallocHost
      ret = aclrtMallocHost(hostAddr, size);
      CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMallocHost failed. ERROR: %d\n", ret); return ret);
      // Create the aclTensor with aclCreateTensor (ND format)
      *tensor = aclCreateTensor(shape.data(), shape.size(), dataType, nullptr, 0, aclFormat::ACL_FORMAT_ND,
                                shape.data(), shape.size(), *deviceAddr);
      // Copy the host-side data to the device-side memory with aclrtMemcpy
      ret = aclrtMemcpy(*deviceAddr, size, *hostAddr, GetShapeSize(shape) * aclDataTypeSize(dataType), ACL_MEMCPY_HOST_TO_DEVICE);
      CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMemcpy failed. ERROR: %d\n", ret); return ret);
      return 0;
  }

  template <typename T>
  int CreateAclTensorNZ(const std::vector<T>& shape, void** deviceAddr, void** hostAddr,
                        aclDataType dataType, aclTensor** tensor) {
      auto size = GetShapeSize(shape) * sizeof(T);
      // Allocate device-side memory with aclrtMalloc
      auto ret = aclrtMalloc(deviceAddr, size, ACL_MEM_MALLOC_HUGE_FIRST);
      CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMalloc failed. ERROR: %d\n", ret); return ret);
      // Allocate host-side memory with aclrtMallocHost
      ret = aclrtMallocHost(hostAddr, size);
      CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMallocHost failed. ERROR: %d\n", ret); return ret);
      // Create the aclTensor with aclCreateTensor (FRACTAL_NZ format)
      *tensor = aclCreateTensor(shape.data(), shape.size(), dataType, nullptr, 0, aclFormat::ACL_FORMAT_FRACTAL_NZ,
                                shape.data(), shape.size(), *deviceAddr);
      // Copy the host-side data to the device-side memory with aclrtMemcpy
      ret = aclrtMemcpy(*deviceAddr, size, *hostAddr, GetShapeSize(shape) * aclDataTypeSize(dataType), ACL_MEMCPY_HOST_TO_DEVICE);
      CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMemcpy failed. ERROR: %d\n", ret); return ret);
      return 0;
  }

  // Convert a 2-D ND shape (H, W) to the 4-D FRACTAL_NZ shape
  // (W/w0, H/h0, h0, w0), where h0 = 16 and w0 = 32 / typeSize.
  // Assumes H and W are divisible by h0 and w0 respectively.
  int TransToNZShape(std::vector<int64_t>& shapeND, size_t typeSize) {
      int64_t h = shapeND[0];
      int64_t w = shapeND[1];
      int64_t h0 = 16;
      int64_t w0 = 32U / typeSize;
      int64_t h1 = h / h0;
      int64_t w1 = w / w0;
      shapeND[0] = w1;
      shapeND[1] = h1;
      shapeND.emplace_back(h0);
      shapeND.emplace_back(w0);
      return 0;
  }

  int main() {
      // 1. Boilerplate: device/stream initialization; see the AscendCL API reference
      // Fill in deviceId according to your actual device
      int32_t deviceId = 0;
      aclrtStream stream;
      auto ret = Init(deviceId, &stream);
      // Handle the check as needed
      CHECK_RET(ret == 0, LOG_PRINT("Init acl failed. ERROR: %d\n", ret); return ret);
      // 2. Construct the inputs and outputs according to the API definition
      std::vector<int64_t> tokenXShape = {8, 1, 7168};              // B,S,He
      std::vector<int64_t> weightDqShape = {7168, 1536};            // He,Hcq
      std::vector<int64_t> weightUqQrShape = {1536, 6144};          // Hcq,N*(D+Dr)
      std::vector<int64_t> weightUkShape = {32, 128, 512};          // N,D,Hckv
      std::vector<int64_t> weightDkvKrShape = {7168, 576};          // He,Hckv+Dr
      std::vector<int64_t> rmsnormGammaCqShape = {1536};            // Hcq
      std::vector<int64_t> rmsnormGammaCkvShape = {512};            // Hckv
      std::vector<int64_t> ropeSinShape = {8, 1, 64};               // B,S,Dr
      std::vector<int64_t> ropeCosShape = {8, 1, 64};               // B,S,Dr
      std::vector<int64_t> cacheIndexShape = {8, 1};                // B,S
      std::vector<int64_t> kvCacheShape = {16, 128, 1, 512};        // BlockNum,BlockSize,Nkv,Hckv
      std::vector<int64_t> krCacheShape = {16, 128, 1, 64};         // BlockNum,BlockSize,Nkv,Dr
      std::vector<int64_t> dequantScaleXShape = {8, 1};             // B*S, 1
      std::vector<int64_t> dequantScaleWDqShape = {1, 1536};        // 1, Hcq
      std::vector<int64_t> dequantScaleWUqQrShape = {1, 6144};      // 1, N*(D+Dr)
      std::vector<int64_t> dequantScaleWDkvKrShape = {1, 576};      // 1, Hckv+Dr
      std::vector<int64_t> quantScaleCkvShape = {1, 512};           // 1, Hckv
      std::vector<int64_t> smoothScaleCqShape = {1, 1536};          // 1, Hcq
      std::vector<int64_t> queryShape = {8, 1, 32, 512};            // B,S,N,Hckv
      std::vector<int64_t> queryRopeShape = {8, 1, 32, 64};         // B,S,N,Dr
      std::vector<int64_t> dequantScaleQNopeShape = {8, 32, 1};     // B*S, N, 1
      double rmsnormEpsilonCq = 1e-5;
      double rmsnormEpsilonCkv = 1e-5;
      char cacheMode[] = "PA_BSND";

      void* tokenXDeviceAddr = nullptr;
      void* weightDqDeviceAddr = nullptr;
      void* weightUqQrDeviceAddr = nullptr;
      void* weightUkDeviceAddr = nullptr;
      void* weightDkvKrDeviceAddr = nullptr;
      void* rmsnormGammaCqDeviceAddr = nullptr;
      void* rmsnormGammaCkvDeviceAddr = nullptr;
      void* ropeSinDeviceAddr = nullptr;
      void* ropeCosDeviceAddr = nullptr;
      void* cacheIndexDeviceAddr = nullptr;
      void* kvCacheDeviceAddr = nullptr;
      void* krCacheDeviceAddr = nullptr;
      void* dequantScaleXDeviceAddr = nullptr;
      void* dequantScaleWDqDeviceAddr = nullptr;
      void* dequantScaleWUqQrDeviceAddr = nullptr;
      void* dequantScaleWDkvKrDeviceAddr = nullptr;
      void* quantScaleCkvDeviceAddr = nullptr;
      void* smoothScaleCqDeviceAddr = nullptr;
      void* queryDeviceAddr = nullptr;
      void* queryRopeDeviceAddr = nullptr;
      void* dequantScaleQNopeDeviceAddr = nullptr;

      void* tokenXHostAddr = nullptr;
      void* weightDqHostAddr = nullptr;
      void* weightUqQrHostAddr = nullptr;
      void* weightUkHostAddr = nullptr;
      void* weightDkvKrHostAddr = nullptr;
      void* rmsnormGammaCqHostAddr = nullptr;
      void* rmsnormGammaCkvHostAddr = nullptr;
      void* ropeSinHostAddr = nullptr;
      void* ropeCosHostAddr = nullptr;
      void* cacheIndexHostAddr = nullptr;
      void* kvCacheHostAddr = nullptr;
      void* krCacheHostAddr = nullptr;
      void* dequantScaleXHostAddr = nullptr;
      void* dequantScaleWDqHostAddr = nullptr;
      void* dequantScaleWUqQrHostAddr = nullptr;
      void* dequantScaleWDkvKrHostAddr = nullptr;
      void* quantScaleCkvHostAddr = nullptr;
      void* smoothScaleCqHostAddr = nullptr;
      void* queryHostAddr = nullptr;
      void* queryRopeHostAddr = nullptr;
      void* dequantScaleQNopeHostAddr = nullptr;

      aclTensor* tokenX = nullptr;
      aclTensor* weightDq = nullptr;
      aclTensor* weightUqQr = nullptr;
      aclTensor* weightUk = nullptr;
      aclTensor* weightDkvKr = nullptr;
      aclTensor* rmsnormGammaCq = nullptr;
      aclTensor* rmsnormGammaCkv = nullptr;
      aclTensor* ropeSin = nullptr;
      aclTensor* ropeCos = nullptr;
      aclTensor* cacheIndex = nullptr;
      aclTensor* kvCache = nullptr;
      aclTensor* krCache = nullptr;
      aclTensor* dequantScaleX = nullptr;
      aclTensor* dequantScaleWDq = nullptr;
      aclTensor* dequantScaleWUqQr = nullptr;
      aclTensor* dequantScaleWDkvKr = nullptr;
      aclTensor* quantScaleCkv = nullptr;
      aclTensor* smoothScaleCq = nullptr;
      aclTensor* query = nullptr;
      aclTensor* queryRope = nullptr;
      aclTensor* dequantScaleQNope = nullptr;

      // Convert the shapes of the three NZ-format weights
      ret = TransToNZShape(weightDqShape, sizeof(int8_t));
      CHECK_RET(ret == 0, LOG_PRINT("trans NZ shape failed.\n"); return ret);
      ret = TransToNZShape(weightUqQrShape, sizeof(int8_t));
      CHECK_RET(ret == 0, LOG_PRINT("trans NZ shape failed.\n"); return ret);
      ret = TransToNZShape(weightDkvKrShape, sizeof(int8_t));
      CHECK_RET(ret == 0, LOG_PRINT("trans NZ shape failed.\n"); return ret);

      // Create the tokenX aclTensor
      ret = CreateAclTensorND(tokenXShape, &tokenXDeviceAddr, &tokenXHostAddr, aclDataType::ACL_INT8, &tokenX);
      CHECK_RET(ret == ACL_SUCCESS, return ret);
      // Create the weightDq aclTensor
      ret = CreateAclTensorNZ(weightDqShape, &weightDqDeviceAddr, &weightDqHostAddr, aclDataType::ACL_INT8, &weightDq);
      CHECK_RET(ret == ACL_SUCCESS, return ret);
      // Create the weightUqQr aclTensor
      ret = CreateAclTensorNZ(weightUqQrShape, &weightUqQrDeviceAddr, &weightUqQrHostAddr, aclDataType::ACL_INT8, &weightUqQr);
      CHECK_RET(ret == ACL_SUCCESS, return ret);
      // Create the weightUk aclTensor
      ret = CreateAclTensorND(weightUkShape, &weightUkDeviceAddr, &weightUkHostAddr, aclDataType::ACL_BF16, &weightUk);
      CHECK_RET(ret == ACL_SUCCESS, return ret);
      // Create the weightDkvKr aclTensor
      ret = CreateAclTensorNZ(weightDkvKrShape, &weightDkvKrDeviceAddr, &weightDkvKrHostAddr, aclDataType::ACL_INT8, &weightDkvKr);
      CHECK_RET(ret == ACL_SUCCESS, return ret);
      // Create the rmsnormGammaCq aclTensor
      ret = CreateAclTensorND(rmsnormGammaCqShape, &rmsnormGammaCqDeviceAddr, &rmsnormGammaCqHostAddr, aclDataType::ACL_BF16, &rmsnormGammaCq);
      CHECK_RET(ret == ACL_SUCCESS, return ret);
      // Create the rmsnormGammaCkv aclTensor
      ret = CreateAclTensorND(rmsnormGammaCkvShape, &rmsnormGammaCkvDeviceAddr, &rmsnormGammaCkvHostAddr, aclDataType::ACL_BF16, &rmsnormGammaCkv);
      CHECK_RET(ret == ACL_SUCCESS, return ret);
      // Create the ropeSin aclTensor
      ret = CreateAclTensorND(ropeSinShape, &ropeSinDeviceAddr, &ropeSinHostAddr, aclDataType::ACL_BF16, &ropeSin);
      CHECK_RET(ret == ACL_SUCCESS, return ret);
      // Create the ropeCos aclTensor
      ret = CreateAclTensorND(ropeCosShape, &ropeCosDeviceAddr, &ropeCosHostAddr, aclDataType::ACL_BF16, &ropeCos);
      CHECK_RET(ret == ACL_SUCCESS, return ret);
      // Create the cacheIndex aclTensor
      ret = CreateAclTensorND(cacheIndexShape, &cacheIndexDeviceAddr, &cacheIndexHostAddr, aclDataType::ACL_INT64, &cacheIndex);
      CHECK_RET(ret == ACL_SUCCESS, return ret);
      // Create the kvCache aclTensor
      ret = CreateAclTensorND(kvCacheShape, &kvCacheDeviceAddr, &kvCacheHostAddr, aclDataType::ACL_INT8, &kvCache);
      CHECK_RET(ret == ACL_SUCCESS, return ret);
      // Create the krCache aclTensor
      ret = CreateAclTensorND(krCacheShape, &krCacheDeviceAddr, &krCacheHostAddr, aclDataType::ACL_BF16, &krCache);
      CHECK_RET(ret == ACL_SUCCESS, return ret);
      // Create the dequantScaleX aclTensor
      ret = CreateAclTensorND(dequantScaleXShape, &dequantScaleXDeviceAddr, &dequantScaleXHostAddr, aclDataType::ACL_FLOAT, &dequantScaleX);
      CHECK_RET(ret == ACL_SUCCESS, return ret);
      // Create the dequantScaleWDq aclTensor
      ret = CreateAclTensorND(dequantScaleWDqShape, &dequantScaleWDqDeviceAddr, &dequantScaleWDqHostAddr, aclDataType::ACL_FLOAT, &dequantScaleWDq);
      CHECK_RET(ret == ACL_SUCCESS, return ret);
      // Create the dequantScaleWUqQr aclTensor
      ret = CreateAclTensorND(dequantScaleWUqQrShape, &dequantScaleWUqQrDeviceAddr, &dequantScaleWUqQrHostAddr, aclDataType::ACL_FLOAT, &dequantScaleWUqQr);
      CHECK_RET(ret == ACL_SUCCESS, return ret);
      // Create the dequantScaleWDkvKr aclTensor
      ret = CreateAclTensorND(dequantScaleWDkvKrShape, &dequantScaleWDkvKrDeviceAddr, &dequantScaleWDkvKrHostAddr, aclDataType::ACL_FLOAT, &dequantScaleWDkvKr);
      CHECK_RET(ret == ACL_SUCCESS, return ret);
      // Create the quantScaleCkv aclTensor
      ret = CreateAclTensorND(quantScaleCkvShape, &quantScaleCkvDeviceAddr, &quantScaleCkvHostAddr, aclDataType::ACL_FLOAT, &quantScaleCkv);
      CHECK_RET(ret == ACL_SUCCESS, return ret);
      // Create the smoothScaleCq aclTensor
      ret = CreateAclTensorND(smoothScaleCqShape, &smoothScaleCqDeviceAddr, &smoothScaleCqHostAddr, aclDataType::ACL_FLOAT, &smoothScaleCq);
      CHECK_RET(ret == ACL_SUCCESS, return ret);
      // Create the query aclTensor
      ret = CreateAclTensorND(queryShape, &queryDeviceAddr, &queryHostAddr, aclDataType::ACL_INT8, &query);
      CHECK_RET(ret == ACL_SUCCESS, return ret);
      // Create the queryRope aclTensor
      ret = CreateAclTensorND(queryRopeShape, &queryRopeDeviceAddr, &queryRopeHostAddr, aclDataType::ACL_BF16, &queryRope);
      CHECK_RET(ret == ACL_SUCCESS, return ret);
      // Create the dequantScaleQNope aclTensor
      ret = CreateAclTensorND(dequantScaleQNopeShape, &dequantScaleQNopeDeviceAddr, &dequantScaleQNopeHostAddr, aclDataType::ACL_FLOAT, &dequantScaleQNope);
      CHECK_RET(ret == ACL_SUCCESS, return ret);

      // 3. Call the CANN operator library API; replace with the concrete API as needed
      uint64_t workspaceSize = 0;
      aclOpExecutor* executor = nullptr;
      // Call the first-stage interface of aclnnMlaPrologV2WeightNz
      ret = aclnnMlaPrologV2WeightNzGetWorkspaceSize(tokenX, weightDq, weightUqQr, weightUk, weightDkvKr, rmsnormGammaCq, rmsnormGammaCkv, ropeSin, ropeCos, cacheIndex, kvCache, krCache,
        dequantScaleX, dequantScaleWDq, dequantScaleWUqQr, dequantScaleWDkvKr, quantScaleCkv, nullptr, smoothScaleCq, rmsnormEpsilonCq, rmsnormEpsilonCkv, cacheMode,
        query, queryRope, dequantScaleQNope, &workspaceSize, &executor);
      CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclnnMlaPrologV2WeightNzGetWorkspaceSize failed. ERROR: %d\n", ret); return ret);
      // Allocate device memory using the workspaceSize computed by the first-stage interface
      void* workspaceAddr = nullptr;
      if (workspaceSize > 0) {
          ret = aclrtMalloc(&workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
          CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("allocate workspace failed. ERROR: %d\n", ret); return ret);
      }
      // Call the second-stage interface of aclnnMlaPrologV2WeightNz
      ret = aclnnMlaPrologV2WeightNz(workspaceAddr, workspaceSize, executor, stream);
      CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclnnMlaPrologV2WeightNz failed. ERROR: %d\n", ret); return ret);

      // 4. Boilerplate: synchronize and wait for the task to finish
      ret = aclrtSynchronizeStream(stream);
      CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSynchronizeStream failed. ERROR: %d\n", ret); return ret);

      // 5. Fetch the output: copy the result from device memory back to the host;
      // adjust to the concrete API definition (query was created as INT8 in this
      // fully quantized sample, so copy it into an int8_t buffer)
      auto size = GetShapeSize(queryShape);
      std::vector<int8_t> resultData(size, 0);
      ret = aclrtMemcpy(resultData.data(), resultData.size() * sizeof(resultData[0]), queryDeviceAddr, size * sizeof(int8_t),
                        ACL_MEMCPY_DEVICE_TO_HOST);
      CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("copy result from device to host failed. ERROR: %d\n", ret); return ret);

      // 6. Destroy the aclTensor objects; adjust to the concrete API definition
      aclDestroyTensor(tokenX);
      aclDestroyTensor(weightDq);
      aclDestroyTensor(weightUqQr);
      aclDestroyTensor(weightUk);
      aclDestroyTensor(weightDkvKr);
      aclDestroyTensor(rmsnormGammaCq);
      aclDestroyTensor(rmsnormGammaCkv);
      aclDestroyTensor(ropeSin);
      aclDestroyTensor(ropeCos);
      aclDestroyTensor(cacheIndex);
      aclDestroyTensor(kvCache);
      aclDestroyTensor(krCache);
      aclDestroyTensor(dequantScaleX);
      aclDestroyTensor(dequantScaleWDq);
      aclDestroyTensor(dequantScaleWUqQr);
      aclDestroyTensor(dequantScaleWDkvKr);
      aclDestroyTensor(quantScaleCkv);
      aclDestroyTensor(smoothScaleCq);
      aclDestroyTensor(query);
      aclDestroyTensor(queryRope);
      aclDestroyTensor(dequantScaleQNope);

      // 7. Release device resources
      aclrtFree(tokenXDeviceAddr);
      aclrtFree(weightDqDeviceAddr);
      aclrtFree(weightUqQrDeviceAddr);
      aclrtFree(weightUkDeviceAddr);
      aclrtFree(weightDkvKrDeviceAddr);
      aclrtFree(rmsnormGammaCqDeviceAddr);
      aclrtFree(rmsnormGammaCkvDeviceAddr);
      aclrtFree(ropeSinDeviceAddr);
      aclrtFree(ropeCosDeviceAddr);
      aclrtFree(cacheIndexDeviceAddr);
      aclrtFree(kvCacheDeviceAddr);
      aclrtFree(krCacheDeviceAddr);
      aclrtFree(dequantScaleXDeviceAddr);
      aclrtFree(dequantScaleWDqDeviceAddr);
      aclrtFree(dequantScaleWUqQrDeviceAddr);
      aclrtFree(dequantScaleWDkvKrDeviceAddr);
      aclrtFree(quantScaleCkvDeviceAddr);
      aclrtFree(smoothScaleCqDeviceAddr);
      aclrtFree(queryDeviceAddr);
      aclrtFree(queryRopeDeviceAddr);
      aclrtFree(dequantScaleQNopeDeviceAddr);

      // 8. Release host resources
      aclrtFreeHost(tokenXHostAddr);
      aclrtFreeHost(weightDqHostAddr);
      aclrtFreeHost(weightUqQrHostAddr);
      aclrtFreeHost(weightUkHostAddr);
      aclrtFreeHost(weightDkvKrHostAddr);
      aclrtFreeHost(rmsnormGammaCqHostAddr);
      aclrtFreeHost(rmsnormGammaCkvHostAddr);
      aclrtFreeHost(ropeSinHostAddr);
      aclrtFreeHost(ropeCosHostAddr);
      aclrtFreeHost(cacheIndexHostAddr);
      aclrtFreeHost(kvCacheHostAddr);
      aclrtFreeHost(krCacheHostAddr);
      aclrtFreeHost(dequantScaleXHostAddr);
      aclrtFreeHost(dequantScaleWDqHostAddr);
      aclrtFreeHost(dequantScaleWUqQrHostAddr);
      aclrtFreeHost(dequantScaleWDkvKrHostAddr);
      aclrtFreeHost(quantScaleCkvHostAddr);
      aclrtFreeHost(smoothScaleCqHostAddr);
      aclrtFreeHost(queryHostAddr);
      aclrtFreeHost(queryRopeHostAddr);
      aclrtFreeHost(dequantScaleQNopeHostAddr);

      if (workspaceSize > 0) {
        aclrtFree(workspaceAddr);
      }
      aclrtDestroyStream(stream);
      aclrtResetDevice(deviceId);
      aclFinalize();

      return 0;
  }