Device侧代码与说明
Device侧代码示例如下,并参考本节内容了解主要操作。
// Device侧文件:cube_kernel.cce
#ifdef __CCE_KT_TEST__
#define __aicore__
#else
#define __aicore__ [aicore]
#endif
extern "C" __global__ __aicore__ void mat_mul_kernel(__gm__ uint8_t *tensor_a, __gm__ uint8_t *tensor_b, __gm__ uint8_t *tensor_c, int m, int k, int n) {
// 指针类型转换,将通用的字节指针(uint8_t*)转换为特定类型的浮点指针(float*)
__gm__ float *aGm = (__gm__ float *)tensor_a;
__gm__ float *bGm = (__gm__ float *)tensor_b;
__gm__ float *cGm = (__gm__ float *)tensor_c;
// 初始化三个需要用到的tensor首地址,分别在L0A、L0B和L0C中
__ca__ float *aL0a = (__ca__ float *)get_imm(0);
__cb__ float *bL0b = (__cb__ float *)get_imm(0);
__cc__ float *cL0c = (__cc__ float *)get_imm(0);
// 数据搬入
load_gm_to_ca(aL0a, aGm, 0, 8, 1, 0, 0, inc);
load_gm_to_cb(bL0b, bGm, 0, 8, 1, 0, 0, inc);
// 同步
set_flag(PIPE_MTE2, PIPE_M, EVENT_ID0);
wait_flag(PIPE_MTE2, PIPE_M, EVENT_ID0);
// 计算
mad(cL0c, aL0a, bL0b, m, k, n, 0, 0, 0, 0, 0, 0, true);
// 同步
set_flag(PIPE_M, PIPE_FIX, EVENT_ID1);
wait_flag(PIPE_M, PIPE_FIX, EVENT_ID1);
// 数据搬出
set_nd_para(4398046576641);
copy_matrix_cc_to_gm(cGm, cL0c, 0, n, m, n, m, 0, 0, 0, false, true);
// 同步
pipe_barrier(PIPE_ALL);
}
- 数据搬入,该接口不具备分形转换能力,所以需要GM中的数据已经是ZZ和ZN格式,在本示例中由Host侧完成分形转换。
将需要进行矩阵计算的两个tensor,从GM搬到L0A/L0B上,这里使用了load_gm_to_ca/load_gm_to_cb接口,其接口原型为:
void load_gm_to_ca(__ca__ float *dst, __gm__ float *src, uint16_t baseIdx, uint8_t repeat, uint16_t srcStride, uint16_t dstGap, uint8_t sid, __cce_scalar::addr_cal_mode_t addr_cal_mode); void load_gm_to_cb(__cb__ float *dst, __gm__ float *src, uint16_t baseIdx, uint8_t repeat, uint16_t srcStride, uint16_t dstGap, uint8_t sid, __cce_scalar::addr_cal_mode_t addr_cal_mode);
- dst:目的地址,aL0a/bL0b。
- src:源地址,aGm/bGm。
- baseIdx:表示src matrix的fractal matrix的index ID。从分形0开始读取,所以设为0。
- repeat:重复搬运的次数。每次搬512字节的分形,待搬运矩阵为32*32*sizeof(float)=4096B,所以repeat=4096/512=8。
- srcStride:每个repeat迭代的源地址stride,单位是16*16的fractal matrix,每个repeat迭代的源地址分形头到第二个分形头的距离。连续读取时取1。
- dstGap:每个repeat迭代的目的地址gap,单位是16*16的fractal matrix,每个repeat迭代的目的地址分形尾到第二个分形头的距离。连续存储时取0。
- sid:预留参数,此处默认为0即可。
- addr_cal_mode:设为inc,表示src地址是增加,数据连续往后顺序读取;设为dec,表示从src往前读取,地址减小。需要递增读取,设为inc。
- 计算。
将在L0A/L0B上的两个矩阵进行矩阵相乘,这里使用mad接口,其接口原型为:
void mad(__cc__ float *dst, __ca__ float *src0, __cb__ float *src1, uint16_t m, uint16_t k, uint16_t n, uint8_t featOffset, uint8_t smaskOffset, uint8_t unitFlag, bool kDirectionAlign, bool isWeightOffset, bool cmatrixSource, bool cmatrixInitVal);
- c:目的地址,cL0c。
- a:矩阵A地址,aL0a。
- b:矩阵B地址,bL0ab。
- m:A矩阵的高,传入参数为m(=32)。
- k:A矩阵的宽或B矩阵的高,传入参数为k(=32)。
- n:B矩阵的宽,传入参数为n(=32)。
- featOffset:特征图矩阵偏移。该参数在当前版本未启用,设为0即可。
- smaskOffset:SMASK缓冲区地址。该参数在当前版本未启用,设为0即可。
- unitFlag:设置为0即可。
- kDirectionAlign:此位仅用于(源矩阵类型为f32, 并且目的矩阵类型也是f32,即f32*f32+f32),其他类型忽略此位。如果=1,则L0AL0B中的矩阵在K方向上对齐到16。否则它将与8对齐。设为0即可。
- isWeightOffset:权重矩阵偏移使能位。该参数在当前版本未启用,设为0即可。
- cmatrixcSource:当cmatrixInitVal设为False时,该参数才有意义。如果cmatrixSource=0,则C矩阵在L0C中,其在L0C中的地址与c[31:0]中的地址相同。如果cmatrixSource=1,则C矩阵位于偏置表中。设为0即可。
- cmatrixInitVal:表示矩阵c初始值控制位,True表示c矩阵初始值为0,False使用c矩阵初始值使用其中的具体数据。设置为true,使得C矩阵初始值为0。
- 数据搬出。
将计算完的结果tensor,从L0C上搬出结果到GM指定位置,并实现NZ2ND的分形转换,这里使用了copy_matrix_cc_to_gm接口,同时为了实现NZ2ND,还需要使用set_nd_para对进行设置。
copy_matrix_cc_to_gm接口的原型为:
void copy_matrix_cc_to_gm(__gm__ float *dst, __cc__ float *src, uint8_t sid, uint16_t NSize, uint16_t MSize, uint32_t dstStride_dst_D, uint16_t srcStride, uint8_t UnitFlagMode, uint64_t QuantPRE, uint8_t ReLUPRE, bool channelSplit, bool NZ2ND_EN);
- dst:目的地址,cGm。
- src:源地址,cL0c。
- sid:预留参数,此处默认为0即可。
- NSize:矩阵C的宽,传入参数为n(=32)。
- MSize:矩阵C的高,传入参数为m(=32)。
- dstStride_dst_D:因为使能了NZ2ND,且目的ND矩阵为连续存储,所以该参数为矩阵C的宽,传入参数为n(=32)。
- srcStride:L0C源矩阵中不同数据块间距离。连续读取时为矩阵C的高,传入参数为m(=32)。
- UnitFlagMode:填0即可。
- QuantPRE:预量化模式,预留参数,设为0即可,代表不做量化。
- ReLUPRE:ReLU模式。不使用ReLU,设为0。
- channelSpilt:是否使能通道拆分。不进行通道切分,设为0。
- NZ2ND_EN:是否使能NZ2ND_EN格式转换。设为1,实现NZ2ND的转换。
set_nd_para接口的原型为:
void set_nd_para(uint64_t config);
- config[0:15]位:表示nd块数量。nd块数量为1。
- config[16:31]位:表示源数据nd块步长,其单位为分形大小。连续读取为1。
- config[32:47]位:表示目的数据nd块步长,其单位为元素。连续存储为32*32=1024。
- config[0:47]为:0100 0000 0000 0000 0000 0000 0001 0000 0000 0000 0001,其十进制为4398046576641。
父主题: 示例1