GatherMask
产品支持情况
产品 |
是否支持 |
|---|---|
Ascend 950PR/Ascend 950DT |
√ |
√ |
|
√ |
|
√ |
|
√ |
|
x |
|
x |
功能说明
以内置固定模式对应的二进制或者用户自定义输入的Tensor数值对应的二进制为gather mask(数据收集的掩码),从源操作数中选取元素写入目的操作数中。
函数原型
- 用户自定义模式
1 2
template <typename T, typename U, GatherMaskMode mode = defaultGatherMaskMode> __aicore__ inline void GatherMask(const LocalTensor<T>& dst, const LocalTensor<T>& src0, const LocalTensor<U>& src1Pattern, const bool reduceMode, const uint32_t mask, const GatherMaskParams& gatherMaskParams, uint64_t& rsvdCnt)
- 内置固定模式
1 2
template <typename T, GatherMaskMode mode = defaultGatherMaskMode> __aicore__ inline void GatherMask(const LocalTensor<T>& dst, const LocalTensor<T>& src0, const uint8_t src1Pattern, const bool reduceMode, const uint32_t mask, const GatherMaskParams& gatherMaskParams, uint64_t& rsvdCnt)
参数说明
参数名称 |
含义 |
|---|---|
T |
源操作数src0和目的操作数dst的数据类型。 Ascend 950PR/Ascend 950DT,支持的数据类型为:int8_t/uint8_t/int16_t/uint16_t/half/bfloat16_t/float/int32_t/uint32_t |
U |
用户自定义模式下src1Pattern的数据类型。支持的数据类型为uint8_t/uint16_t/uint32_t。
|
mode |
预留参数,为后续功能做预留,当前提供默认值,用户无需设置该参数。 |
参数名称 |
输入/输出 |
含义 |
|---|---|---|
dst |
输出 |
目的操作数。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 LocalTensor的起始地址需要32字节对齐。 |
src0 |
输入 |
源操作数。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 LocalTensor的起始地址需要32字节对齐。 数据类型需要与目的操作数保持一致。 |
src1Pattern |
输入 |
gather mask(数据收集的掩码),分为内置固定模式和用户自定义模式两种,根据内置固定模式对应的二进制或者用户自定义输入的Tensor数值对应的二进制从源操作数中选取元素写入目的操作数中。1为选取,0为不选取。
|
reduceMode |
输入 |
用于选择mask参数模式,数据类型为bool,支持如下取值。
|
mask |
输入 |
用于控制每次迭代内参与计算的元素。根据reduceMode,分为两种模式:
|
gatherMaskParams |
输入 |
控制操作数地址步长的数据结构,GatherMaskParams类型。 具体定义请参考${INSTALL_DIR}/include/ascendc/basic_api/interface/kernel_struct_gather.h,${INSTALL_DIR}请替换为CANN软件安装后文件存储路径。 具体参数说明表3。 |
rsvdCnt |
输出 |
该条指令筛选后保留下来的元素计数,对应dstLocal中有效元素个数,数据类型为uint64_t。 |
返回值说明
无
约束说明
- 若调用该接口前为Counter模式,在调用该接口后需要显式设置回Counter模式(接口内部执行结束后会设置为Normal模式)。
调用示例
- 用户自定义Tensor样例示例。
1 2 3 4 5 6 7 8 9 10 11 12 13
uint32_t mask = 70; // 每次迭代内参与计算的元素 uint64_t rsvdCnt = 0; // 保留下来的元素个数 // src0Local:源操作数 // src1Local:存放数据收集的掩码的Tensor // dstLocal:目的操作数 // reduceMode = true; 使用Counter模式 // {}中的参数为: // src0BlockStride = 1; 单次迭代内数据间隔1个datablock,即数据连续读取和写入 // repeatTimes = 2; Counter模式时,仅在部分产品型号下会生效 // src0RepeatStride = 4; 源操作数迭代间数据间隔4个datablock // src1RepeatStride = 0; src1迭代间数据间隔0个datablock,即原位置读取 AscendC::GatherMask (dstLocal, src0Local, src1Local, true, mask, { 1, 2, 4, 0 }, rsvdCnt);
下图为Counter模式配置方式一示意图:
- mask = 70,每一次repeat计算70个元素;
- repeatTimes = 2,共进行2次repeat;
- src0BlockStride = 1,源操作数src0Local单次迭代内datablock之间无间隔;
- src0RepeatStride = 4,源操作数src0Local相邻迭代间的间隔为4个datablock,所以第二次repeat从第33个元素开始处理。
- src1Pattern配置为用户自定义模式。src1RepeatStride = 0,src1Pattern相邻迭代间的间隔为0个datablock,所以第二次repeat仍从src1Pattern的首地址开始处理。
下图为Counter模式配置方式二示意图:
- mask = 70,一共计算70个元素;
- repeatTimes配置不生效,根据源操作数和mask自动推断:源操作数的数据类型为uint32_t,每个迭代处理256Bytes数据,一个迭代处理64个元素,共需要进行2次repeat;
- src0BlockStride = 1,源操作数src0Local单次迭代内datablock之间无间隔;
- src0RepeatStride = 4,源操作数src0Local相邻迭代间的间隔为4个datablock,所以第二次repeat从第33个元素开始处理。
- src1Pattern配置为用户自定义模式。src1RepeatStride = 0,src1Pattern相邻迭代间的间隔为0个datablock,所以第二次repeat仍从src1Pattern的首地址开始处理。
- 内置固定模式样例示例。
1 2 3 4 5 6 7 8 9 10 11 12 13 14
uint32_t mask = 0; // 每次迭代内参与计算的元素,normal模式下mask建议设置为0 uint64_t rsvdCnt = 0; // 用于保存筛选后保留下来的元素个数 uint8_t src1Pattern = 2; // 内置固定模式 // src0Local:源操作数 // src1Pattern :存放数据收集的掩码的Tensor // src0Local:目的地址与源地址复用 // reduceMode = false; 使用normal模式 // {}中的参数为: // src0BlockStride = 1; 单次迭代内数据间隔1个Block,即数据连续读取和写入 // repeatTimes = 4;重复迭代4次 // src0RepeatStride = 8;源操作数迭代数据间隔8个datablock // src1RepeatStride = 0;重复一次,故设置为0 AscendC::GatherMask(src0Local, src0Local, src1Pattern, false, mask, { 1, 4, 8, 0 }, rsvdCnt);
结果示例如下:
输入数据src0Local:[1 2 3 ... 256] 输入数据src1Pattern:src1Pattern = 2; 输出数据dstLocal:[ 2 4 6 8 10 12 14 16 18 20 22 24 26 28 30 32 34 36 38 40 42 44 46 48 50 52 54 56 58 60 62 64 66 68 70 72 74 76 78 80 82 84 86 88 90 92 94 96 98 100 102 104 106 108 110 112 114 116 118 120 122 124 126 128 130 132 134 136 138 140 142 144 146 148 150 152 154 156 158 160 162 164 166 168 170 172 174 176 178 180 182 184 186 188 190 192 194 196 198 200 202 204 206 208 210 212 214 216 218 220 222 224 226 228 230 232 234 236 238 240 242 244 246 248 250 252 254 256 undefined ..undefined] 输出数据rsvdCnt:128

