GatherMask
产品支持情况
产品 |
是否支持 |
|---|---|
Atlas 350 加速卡 |
√ |
√ |
|
√ |
|
√ |
|
√ |
|
x |
|
x |
功能说明
以内置固定模式对应的二进制或者用户自定义输入的Tensor数值对应的二进制为gather mask(数据收集的掩码),从源操作数中选取元素写入目的操作数中。
函数原型
- 用户自定义模式
1 2
template <typename T, typename U, GatherMaskMode mode = defaultGatherMaskMode> __aicore__ inline void GatherMask(const LocalTensor<T>& dst, const LocalTensor<T>& src0, const LocalTensor<U>& src1Pattern, const bool reduceMode, const uint32_t mask, const GatherMaskParams& gatherMaskParams, uint64_t& rsvdCnt)
- 内置固定模式
1 2
template <typename T, GatherMaskMode mode = defaultGatherMaskMode> __aicore__ inline void GatherMask(const LocalTensor<T>& dst, const LocalTensor<T>& src0, const uint8_t src1Pattern, const bool reduceMode, const uint32_t mask, const GatherMaskParams& gatherMaskParams, uint64_t& rsvdCnt)
参数说明
参数名称 |
含义 |
|---|---|
T |
源操作数src0和目的操作数dst的数据类型。 Atlas 350 加速卡,支持的数据类型为:int8_t/uint8_t/int16_t/uint16_t/half/bfloat16_t/float/int32_t/uint32_t |
U |
用户自定义模式下src1Pattern的数据类型。支持的数据类型为uint8_t/uint16_t/uint32_t。
|
mode |
预留参数,为后续功能做预留,当前提供默认值,用户无需设置该参数。 |
参数名称 |
输入/输出 |
含义 |
|---|---|---|
dst |
输出 |
目的操作数。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 LocalTensor的起始地址需要32字节对齐。 |
src0 |
输入 |
源操作数。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 LocalTensor的起始地址需要32字节对齐。 数据类型需要与目的操作数保持一致。 |
src1Pattern |
输入 |
gather mask(数据收集的掩码),分为内置固定模式和用户自定义模式两种,根据内置固定模式对应的二进制或者用户自定义输入的Tensor数值对应的二进制从源操作数中选取元素写入目的操作数中。1为选取,0为不选取。
|
reduceMode |
输入 |
用于选择mask参数模式,数据类型为bool,支持如下取值。
|
mask |
输入 |
用于控制每次迭代内参与计算的元素。根据reduceMode,分为两种模式:
|
gatherMaskParams |
输入 |
控制操作数地址步长的数据结构,GatherMaskParams类型。 具体定义请参考${INSTALL_DIR}/include/ascendc/basic_api/interface/kernel_struct_gather.h,${INSTALL_DIR}请替换为CANN软件安装后文件存储路径。 具体参数说明表3。 |
rsvdCnt |
输出 |
该条指令筛选后保留下来的元素计数,对应dstLocal中有效元素个数,数据类型为uint64_t。 |
返回值说明
无
约束说明
- 若调用该接口前为Counter模式,在调用该接口后需要显式设置回Counter模式(接口内部执行结束后会设置为Normal模式)。
调用示例
- 用户自定义Tensor样例示例。
1 2 3 4 5 6 7 8 9 10 11 12 13
uint32_t mask = 70; // 每次迭代内参与计算的元素 uint64_t rsvdCnt = 0; // 保留下来的元素个数 // src0Local:源操作数 // src1Local:存放数据收集的掩码的Tensor // dstLocal:目的操作数 // reduceMode = true; 使用Counter模式 // {}中的参数为: // src0BlockStride = 1; 单次迭代内数据间隔1个datablock,即数据连续读取和写入 // repeatTimes = 2; Counter模式时,仅在部分产品型号下会生效 // src0RepeatStride = 4; 源操作数迭代间数据间隔4个datablock // src1RepeatStride = 0; src1迭代间数据间隔0个datablock,即原位置读取 AscendC::GatherMask (dstLocal, src0Local, src1Local, true, mask, { 1, 2, 4, 0 }, rsvdCnt);
下图为Counter模式配置方式一示意图:
- mask = 70,每一次repeat计算70个元素;
- repeatTimes = 2,共进行2次repeat;
- src0BlockStride = 1,源操作数src0Local单次迭代内datablock之间无间隔;
- src0RepeatStride = 4,源操作数src0Local相邻迭代间的间隔为4个datablock,所以第二次repeat从第33个元素开始处理。
- src1Pattern配置为用户自定义模式。src1RepeatStride = 0,src1Pattern相邻迭代间的间隔为0个datablock,所以第二次repeat仍从src1Pattern的首地址开始处理。
下图为Counter模式配置方式二示意图:
- mask = 70,一共计算70个元素;
- repeatTimes配置不生效,根据源操作数和mask自动推断:源操作数的数据类型为uint32_t,每个迭代处理256Bytes数据,一个迭代处理64个元素,共需要进行2次repeat;
- src0BlockStride = 1,源操作数src0Local单次迭代内datablock之间无间隔;
- src0RepeatStride = 4,源操作数src0Local相邻迭代间的间隔为4个datablock,所以第二次repeat从第33个元素开始处理。
- src1Pattern配置为用户自定义模式。src1RepeatStride = 0,src1Pattern相邻迭代间的间隔为0个datablock,所以第二次repeat仍从src1Pattern的首地址开始处理。
- 内置固定模式样例示例。
1 2 3 4 5 6 7 8 9 10 11 12 13 14
uint32_t mask = 0; // 每次迭代内参与计算的元素,normal模式下mask建议设置为0 uint64_t rsvdCnt = 0; // 用于保存筛选后保留下来的元素个数 uint8_t src1Pattern = 2; // 内置固定模式 // src0Local:源操作数 // src1Pattern :存放数据收集的掩码的Tensor // dstLocal:目的操作数 // reduceMode = false; 使用normal模式 // {}中的参数为: // src0BlockStride = 1; 单次迭代内数据间隔1个Block,即数据连续读取和写入 // repeatTimes = 1;重复迭代一次 // src0RepeatStride = 0;重复一次,故设置为0 // src1RepeatStride = 0;重复一次,故设置为0 AscendC::GatherMask(dstLocal, src0Local, src1Pattern, false, mask, { 1, 1, 0, 0 }, rsvdCnt);
结果示例如下:
输入数据src0Local:[1 2 3 ... 128] 输入数据src1Pattern:src1Pattern = 2; 输出数据dstLocal:[2 4 6 8 10 12 14 16 18 20 22 24 26 28 30 32 34 36 38 40 42 44 46 48 50 52 54 56 58 60 62 64 66 68 70 72 74 76 78 80 82 84 86 88 90 92 94 96 98 100 102 104 106 108 110 112 114 116 118 120 122 124 126 128 undefined ..undefined] 输出数据rsvdCnt:64

