Select

Applicability

Product support:

  • Atlas A3 training products / Atlas A3 inference products: supported
  • Atlas A2 training products / Atlas A2 inference products: supported
  • Atlas 200I/500 A2 inference products: not supported
  • Atlas inference product's AI Core: supported
  • Atlas inference product's Vector Core: not supported
  • Atlas training products: not supported

Function

Given two source operands src0 and src1, selects elements according to the value (not the bit) at the corresponding position of maskTensor and writes them to the destination operand dst. The selection rule is as follows: if the mask value is 0, the element is taken from src0; otherwise, it is taken from src1.

This API supports multi-dimensional shapes. maskTensor and the source operand tensor must have the same number of elements on the leading (non-last) axes, and the last axis of maskTensor must have at least as many elements as the last axis of the source operand tensor. Any excess elements on the last axis of maskTensor are ignored during computation.

  • The last axis of maskTensor must be 32-byte aligned and the number of elements must be a multiple of 16.
  • The last axis of the source operand tensor must be 32-byte aligned.
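The two alignment rules above can be checked on the host before the kernel is launched. The following is a minimal sketch under stated assumptions: the helper names (IsSrcLastAxisValid, IsMaskLastAxisValid, PaddedMaskLastAxis) are hypothetical and are not part of the Ascend C API.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical host-side helpers; they are NOT part of the Ascend C API.
constexpr std::size_t kAlignBytes = 32;        // last-axis byte alignment required above
constexpr std::size_t kMaskElemMultiple = 16;  // required multiple for the mask element count

// True if a source last axis of `elems` elements, each `elemSize` bytes, is 32-byte aligned.
inline bool IsSrcLastAxisValid(std::size_t elems, std::size_t elemSize) {
    return elems > 0 && (elems * elemSize) % kAlignBytes == 0;
}

// True if a mask last axis is 32-byte aligned and its element count is a multiple of 16.
inline bool IsMaskLastAxisValid(std::size_t elems, std::size_t elemSize) {
    return elems > 0 && (elems * elemSize) % kAlignBytes == 0 &&
           elems % kMaskElemMultiple == 0;
}

// Smallest valid mask last-axis length that is >= srcLastAxis.
inline std::size_t PaddedMaskLastAxis(std::size_t srcLastAxis, std::size_t maskElemSize) {
    assert(srcLastAxis > 0 && maskElemSize > 0);
    // Round up to the next multiple of 16 elements first.
    std::size_t n = ((srcLastAxis + kMaskElemMultiple - 1) / kMaskElemMultiple) *
                    kMaskElemMultiple;
    // Then grow in steps of 16 until the byte alignment also holds.
    while (!IsMaskLastAxisValid(n, maskElemSize)) {
        n += kMaskElemMultiple;
    }
    return n;
}
```

For instance, a source row of 16 half elements (2 bytes each) is valid, and a 1-byte mask type paired with it needs a mask row of 32 elements, because 16 one-byte elements occupy only 16 bytes.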

As shown in the figure below, the source operand src0 is a tensor with a shape of (2, 16) and a data type of half, and its last axis length is 32-byte aligned. The source operand src1 is a scalar with a data type of half. The data type of maskTensor is bool. To meet the alignment requirement, its shape must be (2, 32). Only the masks in blue take effect, and the masks in gray are not involved in computation. This figure shows how the destination operand dstTensor is output.
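The selection rule and the handling of the excess mask elements can be expressed as a small host-side reference. This is only a sketch of the documented behavior, not the device implementation: the function name SelectTensorScalarRef is hypothetical, and plain host types stand in for half and LocalTensor.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical host-side reference; it mirrors the documented selection rule only.
// src is [m, k1]; mask is [m, k2] with k2 >= k1. Mask columns k1..k2-1 are ignored.
template <typename T, typename U>
std::vector<T> SelectTensorScalarRef(const std::vector<T>& src, T scalar,
                                     const std::vector<U>& mask,
                                     std::size_t m, std::size_t k1, std::size_t k2) {
    assert(k2 >= k1);
    assert(src.size() == m * k1 && mask.size() == m * k2);
    std::vector<T> dst(m * k1);
    for (std::size_t row = 0; row < m; ++row) {
        for (std::size_t col = 0; col < k1; ++col) {
            // Mask value 0 selects the tensor element; any other value selects the scalar.
            const bool maskIsZero = (mask[row * k2 + col] == U(0));
            dst[row * k1 + col] = maskIsZero ? src[row * k1 + col] : scalar;
        }
    }
    return dst;
}
```

With the shapes in the example above, the call would use m = 2, k1 = 16, and k2 = 32, so only the first 16 mask values of each row influence the result.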

Principles

The figure below shows the internal algorithm of the Select (SelectWithBytesMask) high-level API, using the float data type, the ND format, a source tensor with a shape of [m, k1], and a mask tensor with a shape of [m, k2] as an example.

Figure 1 Select algorithm block diagram

The computation process is divided into the following steps, all of which are performed on vectors:

  1. GatherMask step: If k1 and k2 are not equal, use GatherMask to reduce the input mask[m, k2] to the shape [m, k1] of src. The excess part of the mask along the k axis is discarded, and the mask shape becomes [m, k1].
  2. Cast step: Cast the mask result in the previous step to the half type.
  3. Compare step: Use the Compare API to compare the mask result in the previous step with 0 to obtain the cmpmask result.
  4. Select step: Select the value at the corresponding position of srcTensor or a scalar value based on the cmpmask result, and output it.
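The four steps above can be emulated on the host, one function per step. This is a sketch under stated assumptions: all function names are hypothetical, float stands in for half, the comparison is assumed to be "equal to 0", and the cmpmask is modeled as one flag per element rather than in whatever packed form the hardware uses.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical host-side emulation; none of these functions are Ascend C APIs.

// Step 1 (GatherMask): keep the first k1 mask elements of each row and drop the rest.
std::vector<std::uint8_t> GatherMaskStep(const std::vector<std::uint8_t>& mask,
                                         std::size_t m, std::size_t k1, std::size_t k2) {
    assert(k2 >= k1 && mask.size() == m * k2);
    std::vector<std::uint8_t> out(m * k1);
    for (std::size_t row = 0; row < m; ++row) {
        for (std::size_t col = 0; col < k1; ++col) {
            out[row * k1 + col] = mask[row * k2 + col];
        }
    }
    return out;
}

// Step 2 (Cast): convert the gathered mask to a floating-point type (float replaces half here).
std::vector<float> CastStep(const std::vector<std::uint8_t>& mask) {
    return std::vector<float>(mask.begin(), mask.end());
}

// Step 3 (Compare): flag every position whose mask value equals 0 (assumed compare mode).
std::vector<bool> CompareStep(const std::vector<float>& maskF) {
    std::vector<bool> isZero(maskF.size());
    for (std::size_t i = 0; i < maskF.size(); ++i) {
        isZero[i] = (maskF[i] == 0.0f);
    }
    return isZero;
}

// Step 4 (Select): take the tensor element where the flag is set, else the scalar.
std::vector<float> SelectStep(const std::vector<float>& src, float scalar,
                              const std::vector<bool>& isZero) {
    assert(src.size() == isZero.size());
    std::vector<float> dst(src.size());
    for (std::size_t i = 0; i < src.size(); ++i) {
        dst[i] = isZero[i] ? src[i] : scalar;
    }
    return dst;
}

// Chains the four steps for a source of shape [m, k1] and a mask of shape [m, k2].
std::vector<float> SelectPipeline(const std::vector<float>& src, float scalar,
                                  const std::vector<std::uint8_t>& mask,
                                  std::size_t m, std::size_t k1, std::size_t k2) {
    assert(src.size() == m * k1);
    return SelectStep(src, scalar, CompareStep(CastStep(GatherMaskStep(mask, m, k1, k2))));
}
```

Splitting the work this way shows why the mask shape must match the source shape before the Select step: once the mask is gathered to [m, k1], the remaining steps are purely element-wise.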

Prototype

  • src0 is srcTensor (tensor), and src1 is srcScalar (scalar).
    template <typename T, typename U, bool isReuseMask = true>
    __aicore__ inline void Select(const LocalTensor<T>& dst, const LocalTensor<T>& src0, T src1, const LocalTensor<U>& mask, const LocalTensor<uint8_t>& sharedTmpBuffer, const SelectWithBytesMaskShapeInfo& info)
    
  • src0 is srcScalar (scalar), and src1 is srcTensor (tensor).
    template <typename T, typename U, bool isReuseMask = true>
    __aicore__ inline void Select(const LocalTensor<T>& dst, T src0, const LocalTensor<T>& src1, const LocalTensor<U>& mask, const LocalTensor<uint8_t>& sharedTmpBuffer, const SelectWithBytesMaskShapeInfo& info)
    

This API requires extra temporary space to store intermediate variables during computation. The developer must allocate this space and pass it in through the sharedTmpBuffer input parameter. To obtain the size of the temporary space (BufferSize) to reserve, use the API described in GetSelectMaxMinTmpSize.

Parameters

Table 1 Template parameters

Parameter

Description

T

Data type of the operand.

For the Atlas A3 training products / Atlas A3 inference products, the supported data types are half and float.

For the Atlas A2 training products / Atlas A2 inference products, the supported data types are half and float.

For the Atlas inference product's AI Core, the supported data types are half and float.

U

Data type of maskTensor.

For the Atlas A3 training products / Atlas A3 inference products, the supported data types are bool, int8_t, uint8_t, int16_t, uint16_t, int32_t, and uint32_t.

For the Atlas A2 training products / Atlas A2 inference products, the supported data types are bool, int8_t, uint8_t, int16_t, uint16_t, int32_t, and uint32_t.

For the Atlas inference product's AI Core, the supported data types are bool, int8_t, uint8_t, int16_t, uint16_t, int32_t, and uint32_t.

isReuseMask

Whether maskTensor may be modified. The default value is true.

If this parameter is true, maskTensor may be modified only when the number of elements on its last axis differs from that on the last axis of srcTensor. In all other cases, maskTensor is not modified.

If this parameter is false, maskTensor is never modified, but more temporary space may be required.
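The trade-off described for isReuseMask can be illustrated with a host-side sketch. It is an assumption about why the two modes differ, not a description of the actual kernel: the function CompactMask is hypothetical and simply shows that dropping the excess mask columns either overwrites the mask in place or needs a separate scratch buffer.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical illustration only; the real buffer handling inside the API is not documented here.
// Compacts a [m, k2] mask to [m, k1] and returns a pointer to the compacted data.
template <bool isReuseMask>
const std::uint8_t* CompactMask(std::vector<std::uint8_t>& mask,
                                std::vector<std::uint8_t>& tmp,
                                std::size_t m, std::size_t k1, std::size_t k2) {
    assert(k2 >= k1 && mask.size() == m * k2);
    if (k1 == k2) {
        return mask.data();  // nothing to drop, so the mask is untouched in both modes
    }
    std::uint8_t* out = nullptr;
    if (isReuseMask) {
        out = mask.data();   // overwrite the mask buffer itself: no extra space needed
    } else {
        tmp.resize(m * k1);  // keep the mask intact at the cost of extra temporary space
        out = tmp.data();
    }
    // A forward copy is safe in place because row * k1 + col <= row * k2 + col.
    for (std::size_t row = 0; row < m; ++row) {
        for (std::size_t col = 0; col < k1; ++col) {
            out[row * k1 + col] = mask[row * k2 + col];
        }
    }
    return out;
}
```

In this model, a caller that still needs the original mask contents after the call would choose the non-reusing mode and budget the additional scratch space.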

Table 2 API parameters

Parameter

Input/Output

Definition

dst

Output

Destination operand.

The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.

src0(srcTensor)

src1(srcTensor)

Input

Source operand. The last axis of the source operand tensor must be 32-byte aligned.

The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.

src1(srcScalar)

src0(srcScalar)

Input

Source operand, a scalar of type T.

mask

Input

Mask tensor, which describes how to select a value between srcTensor and srcScalar. The last axis of maskTensor must be 32-byte aligned and the number of elements must be a multiple of 16.

  • src0 is srcTensor (tensor), and src1 is srcScalar (scalar).

    If the value of mask is 0, the corresponding value of srcTensor is selected and stored in dst. Otherwise, the value of srcScalar is selected and stored in dst.

  • src0 is srcScalar (scalar), and src1 is srcTensor (tensor).

    If the value of mask is 0, the value of srcScalar is selected and stored in dst. Otherwise, the corresponding value of srcTensor is selected and stored in dst.

sharedTmpBuffer

Input

Temporary buffer used to store intermediate variables during computation. The required size can be obtained through GetSelectMaxMinTmpSize.

info

Input

Shape information of srcTensor and maskTensor. The SelectWithBytesMaskShapeInfo type is defined as follows:

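struct SelectWithBytesMaskShapeInfo {
    __aicore__ SelectWithBytesMaskShapeInfo(){};
    uint32_t firstAxis = 0;     // number of elements on the first axis
    uint32_t srcLastAxis = 0;   // number of elements on the last axis of srcTensor
    uint32_t maskLastAxis = 0;  // number of elements on the last axis of maskTensor
};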
  • firstAxis: number of elements on the first axis of srcTensor and maskTensor.
  • srcLastAxis: number of elements on the last axis of srcTensor.
  • maskLastAxis: number of elements on the last axis of maskTensor.

Note:

  • The number of elements on the front axis of srcTensor must be the same as that of maskTensor, which is firstAxis.
  • The following requirements must be met: firstAxis * srcLastAxis = srcTensor.GetSize() and firstAxis * maskLastAxis = maskTensor.GetSize().
  • The number of elements in the last axis of maskTensor must be greater than or equal to that of elements in the last axis of srcTensor. During computation, the redundant part of maskTensor is discarded and does not participate in the computation.
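The three requirements in the note above lend themselves to a host-side consistency check before the shape information is passed to the kernel. The sketch below is illustrative only: ShapeInfoMirror is a hypothetical plain-C++ copy of the three documented fields, and IsShapeInfoConsistent is not part of the Ascend C API.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical host-side mirror of the three documented fields; for illustration only.
struct ShapeInfoMirror {
    std::uint32_t firstAxis = 0;
    std::uint32_t srcLastAxis = 0;
    std::uint32_t maskLastAxis = 0;
};

// Returns true if the shape information matches the tensor sizes (in elements)
// and the mask last axis is at least as long as the source last axis.
inline bool IsShapeInfoConsistent(const ShapeInfoMirror& info,
                                  std::uint64_t srcSize, std::uint64_t maskSize) {
    const std::uint64_t first = info.firstAxis;  // widen to avoid 32-bit overflow
    return first * info.srcLastAxis == srcSize &&
           first * info.maskLastAxis == maskSize &&
           info.maskLastAxis >= info.srcLastAxis;
}
```

For a source of shape (2, 16) and a mask of shape (2, 32), the check passes with firstAxis = 2, srcLastAxis = 16, maskLastAxis = 32, srcSize = 32, and maskSize = 64.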

Returns

None

Restrictions

  • The source operand and the destination operand can reuse the same memory.
  • For details about the operand address alignment requirements, see General Address Alignment Restrictions.
  • If isReuseMask is true and the number of elements on the last axis of maskTensor differs from that of the source operand, the API may overwrite the maskTensor data.

Example

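// Shapes assumed from the result example below: srcTensor is (2, 32) and maskTensor is (2, 32).
AscendC::SelectWithBytesMaskShapeInfo info;
info.firstAxis = 2;
info.srcLastAxis = 32;
info.maskLastAxis = 32;
srcLocal1 = inQueueX1.DeQue<srcType>();
maskLocal = maskQueue.DeQue<maskType>();
// Temporary space for intermediate results.
AscendC::LocalTensor<uint8_t> tmpBuffer = sharedTmpBuffer.Get<uint8_t>();
dstLocal = outQueue.AllocTensor<srcType>();
// A mask value of 0 selects from srcLocal1; any other value selects scalar.
AscendC::Select(dstLocal, srcLocal1, scalar, maskLocal, tmpBuffer, info);
outQueue.EnQue<srcType>(dstLocal);
maskQueue.FreeTensor(maskLocal);
inQueueX1.FreeTensor(srcLocal1);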
Result example:
Input data (srcLocal1):
[-84.6    -24.38    30.97   -30.25    22.28   -92.56    90.44   -58.72  -86.56     5.74     6.754  -86.3    -96.7    -37.38   -81.9     46.9
 -99.4     94.2    -41.78   -60.3    -14.43    78.6      8.93   -65.2    79.94   -46.88     4.516   20.03   -25.56    24.73     0.3223  21.98

 -87.4    -93.9     46.22   -69.9     90.8    -24.17   -96.2    -91.    90.44     9.766   68.25   -57.78   -75.44    -8.86   -91.56    21.6
  76.      82.1    -78.     -23.75    92.     -66.44    75.      94.9   2.62   -90.9     15.945   38.16    50.84    96.94   -59.38    44.22  ]
Input data scalar:
[35.6]
Input (maskLocal):
[False  True False False  True  True False  True  True False False  True False  True False  True  
 True   False False False  True  True  True  True   True False  True False  True  True  True  True 

 False False  True False  True False  True False  True False  True False  True  True  True False
 True False  True False  True False  True  True   True False False False  True False  True  True
]

Output (dstLocal):
[-84.6    35.6    30.97   -30.25   35.6    35.6    90.44   35.6  35.6    5.74    6.754   35.6   -96.7    35.6   -81.9    35.6
  35.6    94.2    -41.78  -60.3    35.6    35.6    35.6    35.6  35.6   -46.88   35.6    20.03   35.6    35.6    35.6    35.6
   
 -87.4   -93.9    35.6    -69.9    35.6   -24.17   35.6   -91.   35.6   9.766  35.6   -57.78   35.6     35.6    35.6    21.6
  35.6    82.1    35.6    -23.75   35.6   -66.44   35.6    35.6  35.6   -90.9    15.945  38.16   35.6    96.94   35.6    35.6  ]