昇腾社区首页
中文
注册
开发者
下载

Device侧代码与说明

Device侧代码示例如下,并参考本节内容了解主要操作。
// Device侧文件:cube_kernel.cce
#ifdef __CCE_KT_TEST__
#define __aicore__ 
#else
#define __aicore__ [aicore]
#endif

extern "C" __global__ __aicore__ void mat_mul_kernel(__gm__ uint8_t *tensor_a, __gm__ uint8_t *tensor_b, __gm__ uint8_t *tensor_c, int m, int k, int n) {

    // 指针类型转换,将通用的字节指针(uint8_t*)转换为特定类型的浮点指针(float*)
    __gm__ float *aGm = (__gm__ float *)tensor_a;
    __gm__ float *bGm = (__gm__ float *)tensor_b;
    __gm__ float *cGm = (__gm__ float *)tensor_c;

    // 初始化三个需要用到的tensor首地址,分别在L0A、L0B和L0C中
    __ca__ float *aL0a = (__ca__ float *)get_imm(0);
    __cb__ float *bL0b = (__cb__ float *)get_imm(0);
    __cc__ float *cL0c = (__cc__ float *)get_imm(0);
    
    // 数据搬入
    load_gm_to_ca(aL0a, aGm, 0, 8, 1, 0, 0, inc); 
    load_gm_to_cb(bL0b, bGm, 0, 8, 1, 0, 0, inc); 
    
    // 同步
    set_flag(PIPE_MTE2, PIPE_M, EVENT_ID0);
    wait_flag(PIPE_MTE2, PIPE_M, EVENT_ID0);

    // 计算
    mad(cL0c, aL0a, bL0b, m, k, n, 0, 0, 0, 0, 0, 0, true);

    // 同步
    set_flag(PIPE_M, PIPE_FIX, EVENT_ID1);
    wait_flag(PIPE_M, PIPE_FIX, EVENT_ID1);
    
    // 数据搬出
    set_nd_para(4398046576641);
    copy_matrix_cc_to_gm(cGm, cL0c, 0, n, m, n, m, 0, 0, 0, false, true);

    // 同步
    pipe_barrier(PIPE_ALL);
}
  1. 数据搬入,该接口不具备分形转换能力,所以需要GM中的数据已经是ZZ和ZN格式,在本示例中由Host侧完成分形转换。

    将需要进行矩阵计算的两个tensor,从GM搬到L0A/L0B上,这里使用了load_gm_to_ca/load_gm_to_cb接口,其接口原型为:

    void load_gm_to_ca(__ca__ float *dst, __gm__ float *src, uint16_t baseIdx, uint8_t repeat, uint16_t srcStride, uint16_t dstGap, uint8_t sid, __cce_scalar::addr_cal_mode_t addr_cal_mode);
    
    void load_gm_to_cb(__cb__ float *dst, __gm__ float *src, uint16_t baseIdx, uint8_t repeat, uint16_t srcStride, uint16_t dstGap, uint8_t sid, __cce_scalar::addr_cal_mode_t addr_cal_mode);
    • dst:目的地址,aL0a/bL0b。
    • src:源地址,aGm/bGm。
    • baseIdx:表示src matrix的fractal matrix的index ID。从分形0开始读取,所以设为0。
    • repeat:重复搬运的次数。每次搬512字节的分形,待搬运矩阵为32*32*sizeof(float)=4096B,所以repeat=4096/512=8。
    • srcStride:每个repeat迭代的源地址stride,单位是16*16的fractal matrix,每个repeat迭代的源地址分形头到第二个分形头的距离。连续读取时取1。
    • dstGap:每个repeat迭代的目的地址gap,单位是16*16的fractal matrix,每个repeat迭代的目的地址分形尾到第二个分形头的距离。连续存储时取0。
    • sid:预留参数,此处默认为0即可。
    • addr_cal_mode:设为inc,表示src地址是增加,数据连续往后顺序读取;设为dec,表示从src往前读取,地址减小。需要递增读取,设为inc。
  2. 计算。

    将在L0A/L0B上的两个矩阵进行矩阵相乘,这里使用mad接口,其接口原型为:

    void mad(__cc__ float *dst, __ca__ float *src0, __cb__ float *src1, uint16_t m, uint16_t k, uint16_t n, uint8_t featOffset, uint8_t smaskOffset, uint8_t unitFlag, bool kDirectionAlign, bool isWeightOffset, bool cmatrixSource, bool cmatrixInitVal);
    • c:目的地址,cL0c。
    • a:矩阵A地址,aL0a。
    • b:矩阵B地址,bL0ab。
    • m:A矩阵的高,传入参数为m(=32)。
    • k:A矩阵的宽或B矩阵的高,传入参数为k(=32)。
    • n:B矩阵的宽,传入参数为n(=32)。
    • featOffset:特征图矩阵偏移。该参数在当前版本未启用,设为0即可。
    • smaskOffset:SMASK缓冲区地址。该参数在当前版本未启用,设为0即可。
    • unitFlag:设置为0即可。
    • kDirectionAlign:此位仅用于(源矩阵类型为f32, 并且目的矩阵类型也是f32,即f32*f32+f32),其他类型忽略此位。如果=1,则L0AL0B中的矩阵在K方向上对齐到16。否则它将与8对齐。设为0即可。
    • isWeightOffset:权重矩阵偏移使能位。该参数在当前版本未启用,设为0即可。
    • cmatrixcSource:当cmatrixInitVal设为False时,该参数才有意义。如果cmatrixSource=0,则C矩阵在L0C中,其在L0C中的地址与c[31:0]中的地址相同。如果cmatrixSource=1,则C矩阵位于偏置表中。设为0即可。
    • cmatrixInitVal:表示矩阵c初始值控制位,True表示c矩阵初始值为0,False使用c矩阵初始值使用其中的具体数据。设置为true,使得C矩阵初始值为0。
  3. 数据搬出。

    将计算完的结果tensor,从L0C上搬出结果到GM指定位置,并实现NZ2ND的分形转换,这里使用了copy_matrix_cc_to_gm接口,同时为了实现NZ2ND,还需要使用set_nd_para对进行设置。

    copy_matrix_cc_to_gm接口的原型为:

    void copy_matrix_cc_to_gm(__gm__ float *dst, __cc__ float *src, uint8_t sid, uint16_t NSize, uint16_t MSize, uint32_t dstStride_dst_D, uint16_t srcStride, uint8_t UnitFlagMode, uint64_t QuantPRE, uint8_t ReLUPRE, bool channelSplit, bool NZ2ND_EN);
    • dst:目的地址,cGm。
    • src:源地址,cL0c。
    • sid:预留参数,此处默认为0即可。
    • NSize:矩阵C的宽,传入参数为n(=32)。
    • MSize:矩阵C的高,传入参数为m(=32)。
    • dstStride_dst_D:因为使能了NZ2ND,且目的ND矩阵为连续存储,所以该参数为矩阵C的宽,传入参数为n(=32)。
    • srcStride:L0C源矩阵中不同数据块间距离。连续读取时为矩阵C的高,传入参数为m(=32)。
    • UnitFlagMode:填0即可。
    • QuantPRE:预量化模式,预留参数,设为0即可,代表不做量化。
    • ReLUPRE:ReLU模式。不使用ReLU,设为0。
    • channelSpilt:是否使能通道拆分。不进行通道切分,设为0。
    • NZ2ND_EN:是否使能NZ2ND_EN格式转换。设为1,实现NZ2ND的转换。

    set_nd_para接口的原型为:

    void set_nd_para(uint64_t config);
    • config[0:15]位:表示nd块数量。nd块数量为1。
    • config[16:31]位:表示源数据nd块步长,其单位为分形大小。连续读取为1。
    • config[32:47]位:表示目的数据nd块步长,其单位为元素。连续存储为32*32=1024。
    • config[0:47]为:0100 0000 0000 0000 0000 0000 0001 0000 0000 0000 0001,其十进制为4398046576641。