Select
Applicability
| Product | Supported |
|---|---|
| | √ |
| | √ |
| | x |
| | √ |
| | x |
| | x |
Function
Given two source operands src0 and src1, selects elements according to the value (not the bit) at the corresponding position of maskTensor and writes them to the destination operand dst. The selection rule is as follows: when the mask value is 0, the element is taken from src0; otherwise, it is taken from src1.
This API supports multi-dimensional shapes. The leading (non-last) axes of maskTensor must contain the same number of elements as those of the source operand tensor, and the last axis of maskTensor must contain at least as many elements as the last axis of the source operand tensor. The excess part of maskTensor is discarded and not used in computation.
- The last axis of maskTensor must be 32-byte aligned and the number of elements must be a multiple of 16.
- The last axis of the source operand tensor must be 32-byte aligned.
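The shape and alignment constraints above can be checked on the host before a kernel is launched. The following is a minimal sketch in plain C++, not part of the Ascend C API: `CheckSelectShapes` and its parameter names are hypothetical and introduced only for illustration.

```cpp
#include <cstddef>
#include <cstdint>

// Hypothetical host-side helper; NOT an Ascend C API.
// srcLast / maskLast: element counts on the last axis.
// srcElemBytes / maskElemBytes: sizeof(T) and sizeof(U).
inline bool CheckSelectShapes(std::uint32_t srcLast, std::size_t srcElemBytes,
                              std::uint32_t maskLast, std::size_t maskElemBytes)
{
    constexpr std::size_t kAlignBytes = 32;
    // Last axis of the source tensor must be 32-byte aligned.
    const bool srcAligned = (srcLast * srcElemBytes) % kAlignBytes == 0;
    // Last axis of the mask must be 32-byte aligned ...
    const bool maskAligned = (maskLast * maskElemBytes) % kAlignBytes == 0;
    // ... and hold a multiple of 16 elements.
    const bool maskMultipleOf16 = maskLast % 16 == 0;
    // The mask may be wider than the source, never narrower.
    const bool maskCoversSrc = maskLast >= srcLast;
    return srcAligned && maskAligned && maskMultipleOf16 && maskCoversSrc;
}
```

For a half source with 16 elements on the last axis and a 1-byte mask type, a mask last axis of 32 passes the check, while a mask last axis of 16 fails because it covers only 16 bytes.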
As shown in the figure below, the source operand src0 is a tensor with a shape of (2, 16) and a data type of half, and its last axis length is 32-byte aligned. The source operand src1 is a scalar with a data type of half. The data type of maskTensor is bool. To meet the alignment requirement, its shape must be (2, 32). Only the masks in blue take effect, and the masks in gray are not involved in computation. This figure shows how the destination operand dstTensor is output.
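The selection rule and the handling of a padded mask can be written as a scalar reference model. The sketch below is plain host-side C++ for illustration only and is not how the API is implemented on the device. `SelectRef` and its parameters are assumed names, and `float` may stand in for `half`, which standard C++ does not provide.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Reference model (assumption, for illustration):
// src0 is a [rows, srcLast] tensor, src1 is a scalar,
// mask is a [rows, maskLast] tensor with maskLast >= srcLast.
// Mask value 0 keeps the src0 element; any other value takes src1.
template <typename T, typename U>
std::vector<T> SelectRef(const std::vector<T>& src0, T src1,
                         const std::vector<U>& mask,
                         std::size_t rows, std::size_t srcLast,
                         std::size_t maskLast)
{
    assert(maskLast >= srcLast);
    assert(src0.size() == rows * srcLast);
    assert(mask.size() == rows * maskLast);
    std::vector<T> dst(rows * srcLast);
    for (std::size_t r = 0; r < rows; ++r) {
        // Only the first srcLast columns of each mask row are read;
        // the padding columns are ignored.
        for (std::size_t c = 0; c < srcLast; ++c) {
            const bool takeScalar = mask[r * maskLast + c] != 0;
            dst[r * srcLast + c] = takeScalar ? src1 : src0[r * srcLast + c];
        }
    }
    return dst;
}
```

With `rows = 2`, `srcLast = 16`, and `maskLast = 32`, this matches the layout described above: each mask row carries 16 effective values followed by 16 padding values that never influence the output.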

Principles
The figure below shows the internal algorithm of the SelectWithBytesMask high-level API, using the float type, the ND format, a source tensor of shape [m, k1], and a mask tensor of shape [m, k2] as an example.
The computation process is divided into the following steps, all of which are performed on vectors:
- GatherMask step: If k1 and k2 are not equal, use GatherMask to reduce the input mask[m, k2] to the shape [m, k1] of src. The excess part of the mask along the k axis is discarded.
- Cast step: Cast the mask result in the previous step to the half type.
- Compare step: Use the Compare API to compare the mask result in the previous step with 0 to obtain the cmpmask result.
- Select step: Select the value at the corresponding position of srcTensor or a scalar value based on the cmpmask result, and output it.
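The four steps can be mirrored on the host as a chain of small functions. This is a hedged sketch in plain C++, not the device implementation: every function name is hypothetical, `float` stands in for `half` in the cast step, and the comparison mode (equal to 0) is an assumption chosen so that the result agrees with the selection rule for the tensor-first overload.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Step 1 (GatherMask): keep the first k1 columns of each mask row.
std::vector<std::uint8_t> GatherMaskRef(const std::vector<std::uint8_t>& mask,
                                        std::size_t m, std::size_t k1, std::size_t k2)
{
    assert(k2 >= k1 && mask.size() == m * k2);
    std::vector<std::uint8_t> out(m * k1);
    for (std::size_t r = 0; r < m; ++r)
        for (std::size_t c = 0; c < k1; ++c)
            out[r * k1 + c] = mask[r * k2 + c];
    return out;
}

// Step 2 (Cast): widen the mask to a floating-point type (float stands in for half).
std::vector<float> CastRef(const std::vector<std::uint8_t>& mask)
{
    return std::vector<float>(mask.begin(), mask.end());
}

// Step 3 (Compare): cmpmask[i] is true where the cast mask equals 0 (assumed mode).
std::vector<bool> CompareWithZeroRef(const std::vector<float>& mask)
{
    std::vector<bool> cmp(mask.size());
    for (std::size_t i = 0; i < mask.size(); ++i) cmp[i] = (mask[i] == 0.0f);
    return cmp;
}

// Step 4 (Select): take the tensor element where cmpmask is true, else the scalar.
std::vector<float> SelectStepRef(const std::vector<float>& src, float scalar,
                                 const std::vector<bool>& cmp)
{
    assert(src.size() == cmp.size());
    std::vector<float> dst(src.size());
    for (std::size_t i = 0; i < src.size(); ++i) dst[i] = cmp[i] ? src[i] : scalar;
    return dst;
}

// Full pipeline for src[m, k1], mask[m, k2].
std::vector<float> SelectPipelineRef(const std::vector<float>& src, float scalar,
                                     const std::vector<std::uint8_t>& mask,
                                     std::size_t m, std::size_t k1, std::size_t k2)
{
    std::vector<std::uint8_t> gathered = mask;
    if (k1 != k2) gathered = GatherMaskRef(mask, m, k1, k2);  // skipped when shapes match
    return SelectStepRef(src, scalar, CompareWithZeroRef(CastRef(gathered)));
}
```

Splitting the model this way makes it visible that only the GatherMask step depends on k2; every later step works on the [m, k1] shape.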
Prototype
- src0 is srcTensor (tensor), and src1 is srcScalar (scalar).
    template <typename T, typename U, bool isReuseMask = true>
    __aicore__ inline void Select(const LocalTensor<T>& dst, const LocalTensor<T>& src0, T src1, const LocalTensor<U>& mask, const LocalTensor<uint8_t>& sharedTmpBuffer, const SelectWithBytesMaskShapeInfo& info)
- src0 is srcScalar (scalar), and src1 is srcTensor (tensor).
    template <typename T, typename U, bool isReuseMask = true>
    __aicore__ inline void Select(const LocalTensor<T>& dst, T src0, const LocalTensor<T>& src1, const LocalTensor<U>& mask, const LocalTensor<uint8_t>& sharedTmpBuffer, const SelectWithBytesMaskShapeInfo& info)
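Because the two prototypes differ only in which operand is the scalar, swapping the argument order also swaps which mask values pick the scalar. The following host-side sketch illustrates this with two plain C++ overloads. `SelectModel` is a hypothetical name, and the code models only the selection rule, not the device API.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Tensor-first model: mask == 0 keeps the tensor element (src0),
// any other mask value takes the scalar (src1).
template <typename T, typename U>
std::vector<T> SelectModel(const std::vector<T>& src0, T src1, const std::vector<U>& mask)
{
    assert(src0.size() == mask.size());
    std::vector<T> dst(src0.size());
    for (std::size_t i = 0; i < dst.size(); ++i)
        dst[i] = (mask[i] == 0) ? src0[i] : src1;
    return dst;
}

// Scalar-first model: mask == 0 keeps the scalar (src0),
// any other mask value takes the tensor element (src1).
template <typename T, typename U>
std::vector<T> SelectModel(T src0, const std::vector<T>& src1, const std::vector<U>& mask)
{
    assert(src1.size() == mask.size());
    std::vector<T> dst(src1.size());
    for (std::size_t i = 0; i < dst.size(); ++i)
        dst[i] = (mask[i] == 0) ? src0 : src1[i];
    return dst;
}
```

For a tensor `{1, 2, 3, 4}`, a scalar `9`, and a mask `{0, 1, 0, 1}`, the tensor-first call yields `{1, 9, 3, 9}` and the scalar-first call yields `{9, 2, 9, 4}`.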
This API requires extra temporary space to store intermediate variables during computation. The developer must allocate this space and pass it in through the sharedTmpBuffer input parameter. To obtain the size of the temporary space (BufferSize) to reserve, use the API described in GetSelectMaxMinTmpSize.
Parameters
| Parameter | Description |
|---|---|
| T | Data type of the operand. |
| U | Data type of maskTensor. |
| isReuseMask | Whether maskTensor can be modified. Defaults to true. If this parameter is set to true, maskTensor may be modified only when the number of elements on the last axis of maskTensor differs from the number of elements on the last axis of srcTensor. In other scenarios, maskTensor is not modified. If this parameter is set to false, maskTensor is not modified in any scenario, but more temporary space may be required. |
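The trade-off behind isReuseMask can be illustrated with a host-side sketch. It assumes, as the GatherMask step suggests, that a mask wider than the source is compacted from [rows, maskLast] to [rows, srcLast], either in place or into scratch memory. `CompactMask` and its parameters are hypothetical names, and the real device implementation may differ.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Compact a [rows, maskLast] mask to [rows, srcLast] (illustrative assumption).
// kReuseMask == true : compact inside `mask` itself; its contents change, no scratch is used.
// kReuseMask == false: compact into `scratch`; `mask` stays untouched, scratch grows.
// Returns a pointer to the compacted data.
template <bool kReuseMask>
const std::uint8_t* CompactMask(std::vector<std::uint8_t>& mask,
                                std::vector<std::uint8_t>& scratch,
                                std::size_t rows, std::size_t srcLast,
                                std::size_t maskLast)
{
    assert(maskLast >= srcLast && mask.size() == rows * maskLast);
    if (maskLast == srcLast) {
        return mask.data();  // nothing to compact; the mask is not modified
    }
    std::uint8_t* out = mask.data();
    if (!kReuseMask) {
        scratch.resize(rows * srcLast);
        out = scratch.data();
    }
    // The write index never passes the read index, so the in-place variant
    // does not overwrite values it still needs to read.
    for (std::size_t r = 0; r < rows; ++r)
        for (std::size_t c = 0; c < srcLast; ++c)
            out[r * srcLast + c] = mask[r * maskLast + c];
    return out;
}
```

In this model the in-place variant needs no extra memory but leaves the caller's mask rearranged, while the scratch variant preserves the mask at the cost of `rows * srcLast` additional bytes.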
| Parameter | Input/Output | Definition |
|---|---|---|
| dst | Output | Destination operand. The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT. |
| src0 (srcTensor) / src1 (srcTensor) | Input | Source operand. The last axis of the source operand tensor must be 32-byte aligned. The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT. |
| src1 (srcScalar) / src0 (srcScalar) | Input | Source operand of scalar type. |
| mask | Input | Mask tensor, which describes how to select a value between srcTensor and srcScalar. The last axis of maskTensor must be 32-byte aligned and the number of elements must be a multiple of 16. |
| sharedTmpBuffer | Input | Temporary buffer that stores intermediate results during computation. The required size can be obtained through GetSelectMaxMinTmpSize. |
| info | Input | Shape information of srcTensor and maskTensor, of type SelectWithBytesMaskShapeInfo. |
Returns
None
Restrictions
- The source operand and the destination operand can share the same memory.
- For details about the operand address alignment requirements, see General Address Alignment Restrictions.
- If the number of elements on the last axis of maskTensor differs from that of the source operand and isReuseMask is true, the API may overwrite the maskTensor data.
Example
    // info carries the shape information of srcLocal1 and maskLocal and must be filled in before the call.
    AscendC::SelectWithBytesMaskShapeInfo info;
    srcLocal1 = inQueueX1.DeQue<srcType>();
    maskLocal = maskQueue.DeQue<maskType>();
    AscendC::LocalTensor<uint8_t> tmpBuffer = sharedTmpBuffer.Get<uint8_t>();
    dstLocal = outQueue.AllocTensor<srcType>();
    // Tensor-first overload: src0 is srcLocal1, src1 is scalar.
    AscendC::Select(dstLocal, srcLocal1, scalar, maskLocal, tmpBuffer, info);
    outQueue.EnQue<srcType>(dstLocal);
    maskQueue.FreeTensor(maskLocal);
    inQueueX1.FreeTensor(srcLocal1);
    Input data (srcLocal1): [-84.6 -24.38 30.97 -30.25 22.28 -92.56 90.44 -58.72 -86.56 5.74 6.754 -86.3 -96.7 -37.38 -81.9 46.9 -99.4 94.2 -41.78 -60.3 -14.43 78.6 8.93 -65.2 79.94 -46.88 4.516 20.03 -25.56 24.73 0.3223 21.98 -87.4 -93.9 46.22 -69.9 90.8 -24.17 -96.2 -91. 90.44 9.766 68.25 -57.78 -75.44 -8.86 -91.56 21.6 76. 82.1 -78. -23.75 92. -66.44 75. 94.9 2.62 -90.9 15.945 38.16 50.84 96.94 -59.38 44.22 ]
    Input data scalar: [35.6]
    Input (maskLocal): [False True False False True True False True True False False True False True False True True False False False True True True True True False True False True True True True False False True False True False True False True False True False True True True False True False True False True False True True True False False False True False True True ]
    Output (dstLocal): [-84.6 35.6 30.97 -30.25 35.6 35.6 90.44 35.6 35.6 5.74 6.754 35.6 -96.7 35.6 -81.9 35.6 35.6 94.2 -41.78 -60.3 35.6 35.6 35.6 35.6 35.6 -46.88 35.6 20.03 35.6 35.6 35.6 35.6 -87.4 -93.9 35.6 -69.9 35.6 -24.17 35.6 -91. 35.6 9.766 35.6 -57.78 35.6 35.6 35.6 21.6 35.6 82.1 35.6 -23.75 35.6 -66.44 35.6 35.6 35.6 -90.9 15.945 38.16 35.6 96.94 35.6 35.6 ]