SelectWithBytesMask

Function Usage

Given two source operands, src0 and src1, selects elements according to the value (per element, not per bit) at the corresponding position of maskTensor and writes them to the destination operand dst. One of the source operands is a tensor and the other is a scalar (see Prototype). The selection rule is as follows: when the mask value is 0, the element is taken from src0; otherwise, the element is taken from src1.
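The selection rule can be modelled on the host in plain C++, which is useful for producing expected results in a test. The sketch below is a reference model only, not the Ascend C implementation; the function name SelectRef and the use of std::vector are assumptions made for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical host-side reference model of the selection rule
// (src0 is the tensor, src1 is the scalar).
// mask[i] == 0 -> take src0[i]; any other value -> take src1.
template <typename T>
std::vector<T> SelectRef(const std::vector<T> &src0, T src1, const std::vector<uint8_t> &mask)
{
    assert(mask.size() >= src0.size());  // the mask must cover every source element
    std::vector<T> dst(src0.size());
    for (std::size_t i = 0; i < src0.size(); ++i) {
        dst[i] = (mask[i] == 0) ? src0[i] : src1;
    }
    return dst;
}
```

For the variant in which src0 is the scalar and src1 is the tensor, the two branches of the conditional are swapped.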

This API supports multi-dimensional shapes. The number of elements on the front (non-last) axes of maskTensor must be the same as that of the source operand tensor, and the number of elements on the last axis of maskTensor must be greater than or equal to that of the source operand tensor. The excess part of maskTensor is discarded and is not used in computation.

  • The last axis of maskTensor must be 32-byte aligned and the number of elements must be a multiple of 16.
  • The last axis of the source operand tensor must be 32-byte aligned.
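These constraints can be checked on the host before the kernel is launched. The following is a minimal sketch; the helper name IsShapeValid and its parameters are hypothetical and are not part of the API.

```cpp
#include <cstdint>

// Hypothetical helper: checks the shape constraints listed above.
// srcTypeSize and maskTypeSize are element sizes in bytes.
inline bool IsShapeValid(uint32_t srcLastAxis, uint32_t srcTypeSize,
                         uint32_t maskLastAxis, uint32_t maskTypeSize)
{
    constexpr uint32_t kAlignBytes = 32;        // 32-byte alignment of the last axis
    constexpr uint32_t kMaskElemMultiple = 16;  // mask element count must be a multiple of 16
    const bool srcAligned = (srcLastAxis * srcTypeSize) % kAlignBytes == 0;
    const bool maskAligned = (maskLastAxis * maskTypeSize) % kAlignBytes == 0;
    const bool maskMultiple = maskLastAxis % kMaskElemMultiple == 0;
    const bool maskCoversSrc = maskLastAxis >= srcLastAxis;
    return srcAligned && maskAligned && maskMultiple && maskCoversSrc;
}
```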

As shown in the figure below, the source operand src0 is a tensor with a shape of (2, 16) and a data type of half, and its last axis length is 32-byte aligned. The source operand src1 is a scalar with a data type of half. The data type of maskTensor is bool. To meet the alignment requirement, its shape must be (2, 32). Only the masks in blue take effect, and the masks in gray are not involved in computation. This figure shows how the destination operand dstTensor is output.
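The padded mask length in this example follows from the two mask constraints: a bool occupies 1 byte, so 16 mask elements are only 16 bytes and the last axis must grow to 32 elements. A minimal sketch of that arithmetic is shown below; the helper name MinMaskLastAxis is hypothetical.

```cpp
#include <cstdint>

// Hypothetical helper: smallest mask last-axis length (in elements) that is
// >= srcLastAxis, a multiple of 16 elements, and a multiple of 32 bytes.
inline uint32_t MinMaskLastAxis(uint32_t srcLastAxis, uint32_t maskTypeSize)
{
    uint32_t n = ((srcLastAxis + 15) / 16) * 16;  // round up to a multiple of 16 elements
    while ((n * maskTypeSize) % 32 != 0) {
        n += 16;  // stay on a multiple of 16 while searching for 32-byte alignment
    }
    return n;
}
```

With srcLastAxis = 16 and a 1-byte mask type, the helper returns 32, which matches the (2, 32) mask shape described above.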

Principles

The figure below shows the internal algorithm of the SelectWithBytesMask high-level API, taking the float data type, the ND format, a source tensor with a shape of [m, k1], and a mask tensor with a shape of [m, k2] as an example.

Figure 1 SelectWithBytesMask algorithm block diagram

The computation process is divided into the following steps, all of which are performed on vectors:

  1. GatherMask step: If k1 and k2 are not equal, use GatherMask to reduce the input mask[m, k2] to the shape [m, k1] of src. The excess part on the k axis of the mask is discarded.
  2. Cast step: Cast the mask result in the previous step to the half type.
  3. Compare step: Use the Compare API to compare the mask result in the previous step with 0 to obtain the cmpmask result.
  4. Select step: Select the value at the corresponding position of srcTensor or a scalar value based on the cmpmask result, and output it.
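The four steps above can be sketched as a host-side model in plain C++. This mirrors only the data flow; the real API executes these steps as vector instructions. The function name SelectPipelineRef is hypothetical, float stands in for half in the Cast step, and the equality comparison in the Compare step is an assumption.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical host-side model for a float source [m, k1] and a byte mask [m, k2].
inline std::vector<float> SelectPipelineRef(const std::vector<float> &src, float scalar,
    const std::vector<uint8_t> &mask, std::size_t m, std::size_t k1, std::size_t k2)
{
    assert(k2 >= k1 && src.size() == m * k1 && mask.size() == m * k2);

    // Step 1 (GatherMask): keep the first k1 mask values of every row.
    std::vector<uint8_t> gathered(m * k1);
    for (std::size_t row = 0; row < m; ++row) {
        for (std::size_t col = 0; col < k1; ++col) {
            gathered[row * k1 + col] = mask[row * k2 + col];
        }
    }

    // Step 2 (Cast): widen the mask values (float stands in for half on the host).
    std::vector<float> casted(gathered.begin(), gathered.end());

    // Step 3 (Compare): cmpMask[i] is true where the mask value equals 0 (assumed compare mode).
    std::vector<bool> cmpMask(casted.size());
    for (std::size_t i = 0; i < casted.size(); ++i) {
        cmpMask[i] = (casted[i] == 0.0f);
    }

    // Step 4 (Select): true -> source element, false -> scalar.
    std::vector<float> dst(src.size());
    for (std::size_t i = 0; i < src.size(); ++i) {
        dst[i] = cmpMask[i] ? src[i] : scalar;
    }
    return dst;
}
```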

Prototype

  • src0 is srcTensor (tensor), and src1 is srcScalar (scalar).
    template <typename T, typename U, bool isReuseMask = true>
    __aicore__ inline void SelectWithBytesMask(const LocalTensor<T> &dst, const LocalTensor<T> &src0, T src1, const LocalTensor<U> &mask, const LocalTensor<uint8_t> &sharedTmpBuffer, const SelectWithBytesMaskShapeInfo &info)
    
  • src0 is srcScalar (scalar), and src1 is srcTensor (tensor).
    template <typename T, typename U, bool isReuseMask = true>
    __aicore__ inline void SelectWithBytesMask(const LocalTensor<T> &dst, T src0, const LocalTensor<T> &src1, const LocalTensor<U> &mask, const LocalTensor<uint8_t> &sharedTmpBuffer, const SelectWithBytesMaskShapeInfo &info)
    

This API requires extra temporary space to store intermediate variables during computation. The temporary space needs to be allocated and passed through the sharedTmpBuffer input parameter by developers. To obtain the size of the temporary space (BufferSize) to be reserved, use the API provided in GetSelectWithBytesMaskMaxMinTmpSize.

Parameters

Table 1 Parameters in the template

Parameter | Description
T | Data type of the operand.
U | Data type of maskTensor.
isReuseMask | Whether maskTensor can be modified. The default value is true. If this parameter is set to true, maskTensor may be modified only when the number of elements on the last axis of maskTensor differs from that on the last axis of srcTensor; in other scenarios, maskTensor is not modified. If this parameter is set to false, maskTensor is not modified in any scenario, but more temporary space may be required.
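The trade-off behind isReuseMask can be illustrated with a host-side model of the mask compaction. This is a sketch under stated assumptions: it assumes that the modification of maskTensor comes from compacting the mask from [m, k2] to [m, k1] in place, and that the extra temporary space in the false case holds the compacted copy. The function name CompactMask is hypothetical, and the actual internal behavior of the API may differ.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical model: compacts mask [m, k2] to [m, k1] and returns a pointer to the result.
// reuseMask == true  -> compact in place (mask contents change when k1 != k2).
// reuseMask == false -> compact into tmp (mask untouched, tmp grows to m * k1 bytes).
inline const uint8_t *CompactMask(std::vector<uint8_t> &mask, std::vector<uint8_t> &tmp,
    std::size_t m, std::size_t k1, std::size_t k2, bool reuseMask)
{
    assert(k2 >= k1 && mask.size() == m * k2);
    if (k1 == k2) {
        return mask.data();  // nothing to discard, the mask is used as is
    }
    uint8_t *out = mask.data();
    if (!reuseMask) {
        tmp.resize(m * k1);
        out = tmp.data();
    }
    // The write index never exceeds the read index, so in-place compaction is safe.
    for (std::size_t row = 0; row < m; ++row) {
        for (std::size_t col = 0; col < k1; ++col) {
            out[row * k1 + col] = mask[row * k2 + col];
        }
    }
    return out;
}
```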

Table 2 API parameters

Parameter | Input/Output | Meaning
dst | Output | Destination operand. The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.
src0 (srcTensor) / src1 (srcTensor) | Input | Source operand tensor. The last axis of the source operand tensor must be 32-byte aligned. The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.
src1 (srcScalar) / src0 (srcScalar) | Input | Source operand of the scalar type.
mask | Input | Mask tensor, which determines whether each value is taken from srcTensor or srcScalar. The last axis of maskTensor must be 32-byte aligned and the number of elements must be a multiple of 16. The value is 0x00 or 0x01. When src0 is srcTensor and src1 is srcScalar: if the mask value is 0, the corresponding value of srcTensor is stored in dst; otherwise, the value of srcScalar is stored in dst. When src0 is srcScalar and src1 is srcTensor: if the mask value is 0, the value of srcScalar is stored in dst; otherwise, the corresponding value of srcTensor is stored in dst.
sharedTmpBuffer | Input | Temporary buffer used to store intermediate results during computation. The required size can be obtained with GetSelectWithBytesMaskMaxMinTmpSize.
info | Input | Shape information of srcTensor and maskTensor. The SelectWithBytesMaskShapeInfo type is defined as follows:

struct SelectWithBytesMaskShapeInfo {
    __aicore__ SelectWithBytesMaskShapeInfo(){};
    uint32_t firstAxis = 0;    // Number of elements on the front axis of srcTensor/maskTensor
    uint32_t srcLastAxis = 0;  // Number of elements on the last axis of srcTensor
    uint32_t maskLastAxis = 0; // Number of elements on the last axis of maskTensor
};
  • The number of elements on the front axis of srcTensor must be the same as that of maskTensor, which is firstAxis.
  • The following requirements must be met: firstAxis * srcLastAxis = srcTensor.GetSize() and firstAxis * maskLastAxis = maskTensor.GetSize().
  • If the number of elements on the last axis of maskTensor is greater than that of srcTensor, the extra part of maskTensor is discarded and is not used in computation.
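The size requirements above can be verified on the host before the shape information is passed to the kernel. The sketch below uses a host-side mirror of the three fields; the names ShapeInfoRef and IsShapeInfoConsistent are hypothetical, and srcSize and maskSize stand for the values returned by srcTensor.GetSize() and maskTensor.GetSize().

```cpp
#include <cstddef>
#include <cstdint>

// Host-side mirror of the fields of SelectWithBytesMaskShapeInfo (for illustration only).
struct ShapeInfoRef {
    uint32_t firstAxis = 0;
    uint32_t srcLastAxis = 0;
    uint32_t maskLastAxis = 0;
};

// Hypothetical check of the requirements listed above.
inline bool IsShapeInfoConsistent(const ShapeInfoRef &info, std::size_t srcSize, std::size_t maskSize)
{
    const std::size_t expectedSrc = static_cast<std::size_t>(info.firstAxis) * info.srcLastAxis;
    const std::size_t expectedMask = static_cast<std::size_t>(info.firstAxis) * info.maskLastAxis;
    return expectedSrc == srcSize && expectedMask == maskSize &&
           info.maskLastAxis >= info.srcLastAxis;
}
```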

Returns

None

Availability

Precautions

  • To save memory, developers can define a tensor shared by the source and destination operands (by address overlapping). As a general instruction restriction, the source operand must completely overlap the destination operand.
  • For details about the alignment requirements of the operand address offset, see General Restrictions.
  • If the number of elements on the last axis of maskTensor is different from that of the source operand and isReuseMask is true, the maskTensor data may be overwritten by the API.

Examples

This example shows only part of the code in the Compute process. To run the sample code, copy the code snippet and replace some code of the Compute function in Template Sample.
AscendC::SelectWithBytesMaskShapeInfo shapeInfo;
shapeInfo.firstAxis = 2;
shapeInfo.srcLastAxis = 32;
shapeInfo.maskLastAxis = 32;
AscendC::SelectWithBytesMask(dstLocal, srcLocal, src1, maskLocal, tmpTensor, shapeInfo);
Result example:
Input (src0Local):
[-84.6    -24.38    30.97   -30.25    22.28   -92.56    90.44   -58.72  -86.56     5.74     6.754  -86.3    -96.7    -37.38   -81.9     46.9
 -99.4     94.2    -41.78   -60.3    -14.43    78.6      8.93   -65.2    79.94   -46.88     4.516   20.03   -25.56    24.73     0.3223  21.98

 -87.4    -93.9     46.22   -69.9     90.8    -24.17   -96.2    -91.    90.44     9.766   68.25   -57.78   -75.44    -8.86   -91.56    21.6
  76.      82.1    -78.     -23.75    92.     -66.44    75.      94.9   2.62   -90.9     15.945   38.16    50.84    96.94   -59.38    44.22  ]
Input (src1):
[35.6]
Input (maskLocal):
[False  True False False  True  True False  True  True False False  True False  True False  True  
 True   False False False  True  True  True  True   True False  True False  True  True  True  True 

 False False  True False  True False  True False  True False  True False  True  True  True False
 True False  True False  True False  True  True   True False False False  True False  True  True
]

Output (dstLocal):
[-84.6    35.6    30.97   -30.25   35.6    35.6    90.44   35.6  35.6    5.74    6.754   35.6   -96.7    35.6   -81.9    35.6
  35.6    94.2    -41.78  -60.3    35.6    35.6    35.6    35.6  35.6   -46.88   35.6    20.03   35.6    35.6    35.6    35.6
   
 -87.4   -93.9    35.6    -69.9    35.6   -24.17   35.6   -91.   35.6   9.766  35.6   -57.78   35.6     35.6    35.6    21.6
  35.6    82.1    35.6    -23.75   35.6   -66.44   35.6    35.6  35.6   -90.9    15.945  38.16   35.6    96.94   35.6    35.6  ]

Template Sample

#include "kernel_operator.h"

template <typename srcType, typename maskType>
class KernelSelect {
public:
    __aicore__ inline KernelSelect()
    {}
    __aicore__ inline void Init(GM_ADDR src1Gm, GM_ADDR maskGm, GM_ADDR dstGm, float scalarValue, uint32_t firstAxis,
        uint32_t srcAxis, uint32_t maskAxis, uint32_t tmpSize)
    {
        uint32_t srcSize = firstAxis * srcAxis;
        uint32_t maskSize = firstAxis * maskAxis;
        src1Global.SetGlobalBuffer(reinterpret_cast<__gm__ srcType *>(src1Gm), srcSize);
        mask_global.SetGlobalBuffer(reinterpret_cast<__gm__ maskType *>(maskGm), maskSize);
        dstGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ srcType *>(dstGm), srcSize);
        pipe.InitBuffer(inQueueX1, 1, srcSize * sizeof(srcType));
        pipe.InitBuffer(maskQueue, 1, maskSize * sizeof(maskType));
        pipe.InitBuffer(tmpQueue, 1, tmpSize);
        bufferSize = srcSize;
        scalar = static_cast<srcType>(scalarValue);
        maskBufferSize = maskSize;
        info.firstAxis = firstAxis;
        info.srcLastAxis = srcAxis;
        info.maskLastAxis = maskAxis;
    }
    __aicore__ inline void Process()
    {
        CopyIn();
        Compute();
        CopyOut();
    }

private:
    __aicore__ inline void CopyIn()
    {
        srcLocal1 = inQueueX1.AllocTensor<srcType>();
        AscendC::DataCopy(srcLocal1, src1Global, bufferSize);
        inQueueX1.EnQue(srcLocal1);
        AscendC::LocalTensor<maskType> maskLocal = maskQueue.AllocTensor<maskType>();
        AscendC::DataCopy(maskLocal, mask_global, maskBufferSize);
        maskQueue.EnQue(maskLocal);
    }
    __aicore__ inline void Compute()
    {
        srcLocal1 = inQueueX1.DeQue<srcType>();
        AscendC::LocalTensor<maskType> maskLocal = maskQueue.DeQue<maskType>();
        AscendC::LocalTensor<uint8_t> tmpLocal = tmpQueue.AllocTensor<uint8_t>();

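        // In-place select: dst and src0 share srcLocal1 (complete overlap), src1 is the scalar.
        AscendC::SelectWithBytesMask(srcLocal1, srcLocal1, scalar, maskLocal, tmpLocal, info);
        // Reverse select: src0 is the scalar and src1 is the tensor.
        // AscendC::SelectWithBytesMask(srcLocal1, scalar, srcLocal1, maskLocal, tmpLocal, info);
        // Do not allow maskLocal to be modified (isReuseMask = false).
        // AscendC::SelectWithBytesMask<srcType, maskType, false>(srcLocal1, srcLocal1, scalar, maskLocal, tmpLocal, info);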
        maskQueue.FreeTensor(maskLocal);
        tmpQueue.FreeTensor(tmpLocal);
    }
    __aicore__ inline void CopyOut()
    {
        AscendC::DataCopy(dstGlobal, srcLocal1, bufferSize);
        inQueueX1.FreeTensor(srcLocal1);
    }

private:
    AscendC::GlobalTensor<srcType> src1Global;
    AscendC::GlobalTensor<srcType> dstGlobal;
    AscendC::GlobalTensor<maskType> mask_global;
    AscendC::TPipe pipe;
    AscendC::TQue<AscendC::QuePosition::VECIN, 1> inQueueX1;
    AscendC::TQue<AscendC::QuePosition::VECIN, 1> maskQueue;
    AscendC::TQue<AscendC::QuePosition::VECIN, 1> tmpQueue;
    AscendC::TQue<AscendC::QuePosition::VECOUT, 1> outQueue;
    AscendC::SelectWithBytesMaskShapeInfo info;
    AscendC::LocalTensor<srcType> srcLocal1;
    uint32_t bufferSize = 0;
    uint32_t maskBufferSize = 0;
    srcType scalar = 0.0f;
};

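// Note: srcSize and maskSize are forwarded to Init as srcAxis and maskAxis,
// that is, as the number of elements on the last axis of the source and the mask.
template <typename srcType, typename maskType>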
__aicore__ void kernel_select_with_bytes_mask_operator(GM_ADDR src1Gm, GM_ADDR maskGm, GM_ADDR dstGm, float scalar,
    uint32_t firstAxis, uint32_t srcSize, uint32_t maskSize, uint32_t tmpSize)
{
    KernelSelect<srcType, maskType> op;
    op.Init(src1Gm, maskGm, dstGm, scalar, firstAxis, srcSize, maskSize, tmpSize);
    op.Process();
}

extern "C" __global__ __aicore__ void kernel_select_with_bytes_mask_operator(
    GM_ADDR src1Gm, GM_ADDR maskGm, GM_ADDR dstGm, GM_ADDR tiling)
{
    GET_TILING_DATA(tilingData, tiling);
    kernel_select_with_bytes_mask_operator<half, bool>(src1Gm,
        maskGm,
        dstGm,
        tilingData.scalarValue,
        tilingData.firstAxis,
        tilingData.srcSize,
        tilingData.maskSize,
        tilingData.tmpSize);
}