SelectWithBytesMask
函数功能
给定两个源操作数src0和src1,根据maskTensor相应位置的值(非bit位)选取元素,得到目的操作数dst。选择的规则为:当Mask的值为0时,从src0中选取,否则从src1选取。
该接口支持多维Shape,需满足maskTensor和源操作数Tensor的前轴(非尾轴)元素个数相同,且maskTensor尾轴元素个数大于等于源操作数尾轴元素个数,maskTensor多余部分丢弃不参与计算。
- maskTensor尾轴需32字节对齐且元素个数为16的倍数。
- 源操作数Tensor尾轴需32字节对齐。
如下图样例,源操作数src0为Tensor,shape为(2,16),数据类型为half,尾轴长度满足32字节对齐;源操作数src1为scalar,数据类型为half;maskTensor的数据类型为bool,为满足对齐要求shape为(2,32),仅有图中蓝色部分的mask掩码生效,灰色部分不参与计算。输出目的操作数dstTensor如下图所示。
实现原理
以float类型,ND格式,shape为[m, k1]的source输入Tensor,shape为[m, k2]的mask Tensor为例,描述SelectWithBytesMask高阶API内部算法框图,如下图所示。

计算过程分为如下几步,均在Vector上进行:
- GatherMask步骤:如果k1, k2不相等,则根据src的shape[m, k1],对输入mask[m, k2]通过GatherMask进行reduce计算,使得mask的k轴多余部分被舍去,shape转换为[m, k1];
- Cast步骤:将上一步的mask结果cast成half类型;
- Compare步骤:使用Compare接口将上一步的mask结果与0进行比较,得到cmpmask结果;
- Select步骤:根据cmpmask的结果,选择srcTensor相应位置的值或者scalar值,输出Output。
函数原型
- src0为srcTensor(tensor类型),src1为srcScalar(scalar类型)
1 2
template <typename T, typename U, bool isReuseMask = true> __aicore__ inline void SelectWithBytesMask(const LocalTensor<T> &dst, const LocalTensor<T> &src0, T src1, const LocalTensor<U> &mask, const LocalTensor<uint8_t> &sharedTmpBuffer, const SelectWithBytesMaskShapeInfo &info)
- src0为srcScalar(scalar类型),src1为srcTensor(tensor类型)
1 2
template <typename T, typename U, bool isReuseMask = true> __aicore__ inline void SelectWithBytesMask(const LocalTensor<T> &dst, T src0, const LocalTensor<T> &src1, const LocalTensor<U> &mask, const LocalTensor<uint8_t> &sharedTmpBuffer, const SelectWithBytesMaskShapeInfo &info)
该接口需要额外的临时空间来存储计算过程中的中间变量。临时空间需要开发者申请并通过sharedTmpBuffer入参传入。临时空间大小BufferSize的获取方式如下:通过GetSelectWithBytesMaskMaxMinTmpSize中提供的接口获取需要预留空间范围的大小。
参数说明
参数名 |
描述 |
---|---|
T |
操作数的数据类型。 |
U |
掩码Tensor mask的数据类型。 |
isReuseMask |
是否允许修改maskTensor。默认为True。 取值为True时,仅在maskTensor尾轴元素个数和srcTensor尾轴元素个数不同的情况下,maskTensor可能会被修改;其余场景,maskTensor不会修改。 为False时,任意场景下,maskTensor均不会修改,但可能会需要更多的临时空间。 |
参数名称 |
输入/输出 |
含义 |
||
---|---|---|---|---|
dst |
输出 |
目的操作数。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 |
||
src0(srcTensor) src1(srcTensor) |
输入 |
源操作数。源操作数Tensor尾轴需32字节对齐。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 |
||
src1(srcScalar) src0(srcScalar) |
输入 |
源操作数。类型为scalar。 |
||
mask |
输入 |
掩码Tensor。用于描述如何选择srcTensor和srcScalar之间的值。maskTensor尾轴需32字节对齐且元素个数为16的倍数。 取值为0x00/0x01。
|
||
sharedTmpBuffer |
输入 |
该API用于计算的临时空间,所需空间大小根据GetSelectWithBytesMaskMaxMinTmpSize获取。 |
||
info |
输入 |
描述SrcTensor和maskTensor的shape信息。SelectWithBytesMaskShapeInfo类型,定义如下:
|
返回值
无
支持的型号
约束说明
- 为了节省地址空间,开发者可以定义一个Tensor,供源操作数与目的操作数同时使用(即地址重叠),相关约束如下:源操作数与目的操作数之间要求100%完全重叠,不支持部分重叠。
- 操作数地址偏移对齐要求请参见通用约束。
- maskTensor尾轴元素个数和源操作数尾轴元素个数不同的情况下, maskTensor的数据有可能被接口改写。
调用示例
完整调用样例请参考SelectWithBytesMaskCustom算子样例。
1 2 3 4 5 6 7 8 9 | AscendC::SelectWithBytesMaskShapeInfo info; srcLocal1 = inQueueX1.DeQue<srcType>(); maskLocal = maskQueue.DeQue<maskType>(); AscendC::LocalTensor<uint8_t> tmpBuffer = sharedTmpBuffer.Get<uint8_t>(); dstLocal = outQueue.AllocTensor<srcType>(); AscendC::SelectWithBytesMask(dstLocal, srcLocal1, scalar, maskLocal, tmpBuffer, info); outQueue.EnQue<srcType>(dstLocal); maskQueue.FreeTensor(maskLocal); inQueueX1.FreeTensor(srcLocal1); |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 | 输入数据(src0Local): [-84.6 -24.38 30.97 -30.25 22.28 -92.56 90.44 -58.72 -86.56 5.74 6.754 -86.3 -96.7 -37.38 -81.9 46.9 -99.4 94.2 -41.78 -60.3 -14.43 78.6 8.93 -65.2 79.94 -46.88 4.516 20.03 -25.56 24.73 0.3223 21.98 -87.4 -93.9 46.22 -69.9 90.8 -24.17 -96.2 -91. 90.44 9.766 68.25 -57.78 -75.44 -8.86 -91.56 21.6 76. 82.1 -78. -23.75 92. -66.44 75. 94.9 2.62 -90.9 15.945 38.16 50.84 96.94 -59.38 44.22 ] 输入数据(src1): [35.6] 输入数据(maskLocal): [False True False False True True False True True False False True False True False True True False False False True True True True True False True False True True True True False False True False True False True False True False True False True True True False True False True False True False True True True False False False True False True True ] 输出数据(dstLocal): [-84.6 35.6 30.97 -30.25 35.6 35.6 90.44 35.6 35.6 5.74 6.754 35.6 -96.7 35.6 -81.9 35.6 35.6 94.2 -41.78 -60.3 35.6 35.6 35.6 35.6 35.6 -46.88 35.6 20.03 35.6 35.6 35.6 35.6 -87.4 -93.9 35.6 -69.9 35.6 -24.17 35.6 -91. 35.6 9.766 35.6 -57.78 35.6 35.6 35.6 21.6 35.6 82.1 35.6 -23.75 35.6 -66.44 35.6 35.6 35.6 -90.9 15.945 38.16 35.6 96.94 35.6 35.6 ] |