MrgSort

Applicability

  • Atlas A3 training products/Atlas A3 inference products: supported
  • Atlas A2 training products/Atlas A2 inference products: supported
  • Atlas 200I/500 A2 inference products: not supported
  • Atlas inference product's AI Core: supported
  • Atlas inference product's Vector Core: not supported
  • Atlas training products: not supported

Function

Merges up to four sorted lists into one list. The result is sorted in descending order of the score field. The data layout mode depends on the product:

For the Atlas A3 training products/Atlas A3 inference products, mode 1 is used.

For the Atlas A2 training products/Atlas A2 inference products, mode 1 is used.

For the Atlas inference product's AI Core, mode 2 is used.

  • Layout mode 1:
    The input of MrgSort is typically the output of Sort. The list structure is as follows:
    • When the data type is float, each structure occupies 8 bytes.

    • When the data type is half, each structure occupies 8 bytes, but 2 bytes in the middle are reserved.

  • Layout mode 2: region proposal

    The input and output data are region proposals. For details, see mode 2 in Sort.
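As a host-side illustration of the 8-byte structures in layout mode 1, the sketch below models one element for each data type. The struct names and the field order (score first, 32-bit index last) are assumptions made for illustration, and uint16_t merely stands in for the 16-bit storage of a half value. This is plain C++ for reasoning about sizes, not Ascend C device code.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical host-side models of one layout-mode-1 element.
// Assumption: the score comes first and the 32-bit index comes last.
struct SortElemFloat {
    float score;     // 4 bytes
    uint32_t index;  // 4 bytes
};

struct SortElemHalf {
    uint16_t scoreBits;  // 2 bytes: raw bits of the half-precision score
    uint16_t reserved;   // 2 bytes: reserved, located in the middle
    uint32_t index;      // 4 bytes
};

static_assert(sizeof(SortElemFloat) == 8, "float element must occupy 8 bytes");
static_assert(sizeof(SortElemHalf) == 8, "half element must occupy 8 bytes");
static_assert(offsetof(SortElemHalf, reserved) == 2, "reserved bytes sit in the middle");

// Bytes needed to hold `count` elements of either data type in layout mode 1.
constexpr std::size_t Mode1Bytes(std::size_t count) { return count * 8; }
```

Under these assumptions, a list of 32 elements occupies 256 bytes for both half and float.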

Prototype

template <typename T, bool isExhaustedSuspension = false>
__aicore__ inline void MrgSort(const LocalTensor<T> &dst, const MrgSortSrcList<T> &sortList, const uint16_t elementCountList[4], uint32_t sortedNum[4], uint16_t validBit, const int32_t repeatTime)

Parameters

Table 1 Template parameters

  • T
    Data type of the operand.
    For the Atlas A3 training products/Atlas A3 inference products, the supported data types are half and float.
    For the Atlas A2 training products/Atlas A2 inference products, the supported data types are half and float.
    For the Atlas inference product's AI Core, the supported data types are half and float.

  • isExhaustedSuspension
    Whether to stop merging once any one list is exhausted (that is, all elements of that list have been merged into the destination operand). The type is bool. The options are as follows:
    • false: Merging stops only when all lists are exhausted.
    • true: Merging stops as soon as one list is exhausted.
    The default value is false.
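The effect of isExhaustedSuspension can be illustrated with a small host-side model. The sketch below is plain C++, not the Ascend C API: MergeResult and MergeDescending are hypothetical names, scores are modeled as plain floats, and resolving ties in favor of the lower-numbered list is an assumption.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical host-side model, not the Ascend C API.
struct MergeResult {
    std::vector<float> merged;            // scores written to the destination
    std::array<uint32_t, 4> sortedNum{};  // elements consumed from each list
};

// Merges up to four descending lists. When stopOnExhaustion is true, merging
// stops as soon as a list that contributed elements has been fully consumed.
inline MergeResult MergeDescending(const std::array<std::vector<float>, 4>& lists,
                                   bool stopOnExhaustion) {
    MergeResult r;
    std::array<std::size_t, 4> pos{};
    for (;;) {
        int best = -1;
        for (int i = 0; i < 4; ++i) {
            if (pos[i] == lists[i].size()) continue;  // list i has no elements left
            // Assumption: strict '>' keeps the lower-numbered list on ties.
            if (best < 0 || lists[i][pos[i]] > lists[best][pos[best]]) best = i;
        }
        if (best < 0) break;  // every list is exhausted
        r.merged.push_back(lists[best][pos[best]]);
        ++pos[best];
        ++r.sortedNum[best];
        if (stopOnExhaustion && pos[best] == lists[best].size()) break;
    }
    return r;
}
```

For the lists {9, 5, 1} and {8, 7, 6}, the model with stopOnExhaustion set to true stops after writing 9, 8, 7, 6 and reports one element taken from the first list and three from the second. With stopOnExhaustion set to false, all six elements are merged.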

Table 2 API parameters

  • dst (output)
    Destination operand, which stores the sorted data.
    The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.

  • sortList (input)
    Source operand of the MrgSortSrcList struct type, which holds the two to four sorted lists to be merged. For details about its members, see Table 3.

    template <typename T>
    struct MrgSortSrcList {
        LocalTensor<T> src1;
        LocalTensor<T> src2;
        LocalTensor<T> src3; // Can be empty if fewer than three lists are merged.
        LocalTensor<T> src4; // Can be empty if fewer than four lists are merged.
    };

  • elementCountList (input)
    Lengths of the four source lists (layout mode 1: number of 8-byte structures; layout mode 2: number of 16 x sizeof(T)-byte structures). The data type is a uint16_t array of length 4. In theory, each element is in the range [0, 4095], but the data must also fit in the storage space of the UB.

  • sortedNum (output)
    Number of elements already merged from each list when merging stops in exhaustion mode (that is, when isExhaustedSuspension is true). The data type is a uint32_t array of length 4.

  • validBit (input)
    Bit mask that indicates which lists are valid. The values are as follows:
    • 0b11: The first two lists are valid.
    • 0b111: The first three lists are valid.
    • 0b1111: All four lists are valid.

  • repeatTime (input)
    Number of iterations. After each iteration, the source and destination addresses advance by the total length of the four lists. Value range: repeatTime ∈ [1, 255].
    The repeatTime parameter takes effect only when all of the following conditions are met:
    • sortList contains four lists and validBit is 15 (0b1111).
    • The four source lists have the same length.
    • The four source lists are stored contiguously.
    • isExhaustedSuspension is false.

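The validBit values and the conditions under which repeatTime takes effect can be expressed as two small host-side checks. The helper names ValidBitFor and RepeatTimeTakesEffect are hypothetical and not part of the Ascend C API. The sketch only restates the rules listed above in plain C++.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical host-side helpers, not part of the Ascend C API.

// Returns the validBit mask for the first `listCount` lists,
// or 0 if the count is not 2, 3, or 4.
constexpr uint16_t ValidBitFor(int listCount) {
    return (listCount >= 2 && listCount <= 4)
               ? static_cast<uint16_t>((1u << listCount) - 1u)
               : static_cast<uint16_t>(0);
}

// Checks the four documented conditions under which repeatTime takes effect.
// `contiguous` must be supplied by the caller, because this model has no view
// of the buffer addresses.
inline bool RepeatTimeTakesEffect(uint16_t validBit,
                                  const uint16_t (&elementCountList)[4],
                                  bool contiguous,
                                  bool isExhaustedSuspension) {
    return validBit == 0b1111 &&
           elementCountList[0] == elementCountList[1] &&
           elementCountList[1] == elementCountList[2] &&
           elementCountList[2] == elementCountList[3] &&
           contiguous && !isExhaustedSuspension;
}
```

For example, ValidBitFor(3) yields 0b111, and four contiguous lists of 32 elements each with validBit 0b1111 and isExhaustedSuspension set to false satisfy all four conditions.
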
Table 3 MrgSortSrcList parameters

  • src1 (input)
    Source operand, which stores the first sorted list.
    The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.
    The source operand must have the same data type as the destination operand.
    For the Atlas A3 training products/Atlas A3 inference products, the supported data types are half and float.
    For the Atlas A2 training products/Atlas A2 inference products, the supported data types are half and float.
    For the Atlas inference product's AI Core, the supported data types are half and float.

  • src2 (input)
    Source operand, which stores the second sorted list.
    The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.
    The source operand must have the same data type as the destination operand.
    For the Atlas A3 training products/Atlas A3 inference products, the supported data types are half and float.
    For the Atlas A2 training products/Atlas A2 inference products, the supported data types are half and float.
    For the Atlas inference product's AI Core, the supported data types are half and float.

  • src3 (input)
    Source operand, which stores the third sorted list.
    The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.
    The source operand must have the same data type as the destination operand.
    For the Atlas A3 training products/Atlas A3 inference products, the supported data types are half and float.
    For the Atlas A2 training products/Atlas A2 inference products, the supported data types are half and float.
    For the Atlas inference product's AI Core, the supported data types are half and float.

  • src4 (input)
    Source operand, which stores the fourth sorted list.
    The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.
    The source operand must have the same data type as the destination operand.
    For the Atlas A3 training products/Atlas A3 inference products, the supported data types are half and float.
    For the Atlas A2 training products/Atlas A2 inference products, the supported data types are half and float.
    For the Atlas inference product's AI Core, the supported data types are half and float.

Returns

None

Constraints

  • If score[i] equals score[j] and i > j, score[j] is selected first. That is, elements with equal scores keep their input order.
  • The output of each iteration is sorted, but the outputs of different iterations are not sorted relative to one another.
  • For details about the operand address alignment requirements, see General Address Alignment Restrictions.
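The first two constraints can be illustrated with a host-side model of repeated four-way merging. The sketch below is plain C++, not Ascend C: Elem and MergeGroups are hypothetical names, each element is modeled as a (score, index) pair, and the tie rule is implemented as "the earlier list wins", which is an assumption about how the constraint maps onto the four lists.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

using Elem = std::pair<float, uint32_t>;  // (score, index); hypothetical host-side element

// Models repeated four-way merging: `src` holds repeatTime groups stored back
// to back, each group consisting of four descending lists of `len` elements.
inline std::vector<Elem> MergeGroups(const std::vector<Elem>& src, std::size_t len,
                                     int repeatTime) {
    std::vector<Elem> dst;
    dst.reserve(src.size());
    for (int r = 0; r < repeatTime; ++r) {
        const std::size_t base = static_cast<std::size_t>(r) * 4 * len;
        std::size_t pos[4] = {0, 0, 0, 0};
        for (std::size_t n = 0; n < 4 * len; ++n) {
            int best = -1;
            for (int i = 0; i < 4; ++i) {
                if (pos[i] == len) continue;  // list i of this group is exhausted
                const Elem& cand = src[base + i * len + pos[i]];
                // Strict '>' means equal scores keep their input order.
                if (best < 0 || cand.first > src[base + best * len + pos[best]].first) best = i;
            }
            dst.push_back(src[base + best * len + pos[best]]);
            ++pos[best];
        }
    }
    return dst;
}
```

With two groups of single-element lists, scores {1, 3, 3, 2} followed by {9, 9, 8, 7}, the model emits the two elements with score 3 in their input order, and the second group starts with 9 directly after the 1 that ends the first group. Each group is sorted, but the output as a whole is not.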

Example

  • Processing 128 pieces of half-type data

    This example applies to:

    Atlas A2 training products/Atlas A2 inference products

    Atlas A3 training products/Atlas A3 inference products

    #include "kernel_operator.h"
    template <typename T>
    class FullSort
    {
    public:
        __aicore__ inline FullSort() {}
        __aicore__ inline void Init(__gm__ uint8_t *srcValueGm, __gm__ uint8_t *srcIndexGm, __gm__ uint8_t *dstValueGm, __gm__ uint8_t *dstIndexGm)
        {
            concatRepeatTimes = elementCount / 16;
            inBufferSize = elementCount * sizeof(uint32_t);
            outBufferSize = elementCount * sizeof(uint32_t);
            calcBufferSize = elementCount * 8;
            tmpBufferSize = elementCount * 8;
            sortedLocalSize = elementCount * 4;
            sortRepeatTimes = elementCount / 32;
            extractRepeatTimes = elementCount / 32;
            sortTmpLocalSize = elementCount * 4;
            valueGlobal.SetGlobalBuffer((__gm__ T *)srcValueGm);
            indexGlobal.SetGlobalBuffer((__gm__ uint32_t *)srcIndexGm);
            dstValueGlobal.SetGlobalBuffer((__gm__ T *)dstValueGm);
            dstIndexGlobal.SetGlobalBuffer((__gm__ uint32_t *)dstIndexGm);
            pipe.InitBuffer(queIn, 2, inBufferSize);
            pipe.InitBuffer(queOut, 2, outBufferSize);
            pipe.InitBuffer(queCalc, 1, calcBufferSize * sizeof(T));
            pipe.InitBuffer(queTmp, 2, tmpBufferSize * sizeof(T));
        }
        __aicore__ inline void Process()
        {
            CopyIn();
            Compute();
            CopyOut();
        }
    
    private:
        __aicore__ inline void CopyIn()
        {
            AscendC::LocalTensor<T> valueLocal = queIn.AllocTensor<T>();
            AscendC::DataCopy(valueLocal, valueGlobal, elementCount);
            queIn.EnQue(valueLocal);
            AscendC::LocalTensor<uint32_t> indexLocal = queIn.AllocTensor<uint32_t>();
            AscendC::DataCopy(indexLocal, indexGlobal, elementCount);
            queIn.EnQue(indexLocal);
        }
        __aicore__ inline void Compute()
        {
            AscendC::LocalTensor<T> valueLocal = queIn.DeQue<T>();
            AscendC::LocalTensor<uint32_t> indexLocal = queIn.DeQue<uint32_t>();
            AscendC::LocalTensor<T> sortedLocal = queCalc.AllocTensor<T>();
            AscendC::LocalTensor<T> concatTmpLocal = queTmp.AllocTensor<T>();
            AscendC::LocalTensor<T> sortTmpLocal = queTmp.AllocTensor<T>();
            AscendC::LocalTensor<T> dstValueLocal = queOut.AllocTensor<T>();
            AscendC::LocalTensor<uint32_t> dstIndexLocal = queOut.AllocTensor<uint32_t>();
            AscendC::LocalTensor<T> concatLocal;
    
            // Preprocess the scores into the input format required by Sort.
            AscendC::Concat(concatLocal, valueLocal, concatTmpLocal, concatRepeatTimes);
            // Sort the data block by block; sortRepeatTimes = elementCount / 32 yields four sorted blocks.
            AscendC::Sort<T, false>(sortedLocal, concatLocal, indexLocal, sortTmpLocal, sortRepeatTimes);
            // Split the Sort output into four sorted lists of equal length.
            uint32_t singleMergeTmpElementCount = elementCount / 4;
            uint32_t baseOffset = AscendC::GetSortOffset<T>(singleMergeTmpElementCount);
            AscendC::MrgSortSrcList sortList = AscendC::MrgSortSrcList(sortedLocal[0], sortedLocal[baseOffset], sortedLocal[2 * baseOffset], sortedLocal[3 * baseOffset]);
            uint16_t singleDataSize = elementCount / 4;
            const uint16_t elementCountList[4] = {singleDataSize, singleDataSize, singleDataSize, singleDataSize};
            uint32_t sortedNum[4];
            // Merge all four lists (validBit = 0b1111) in a single iteration.
            AscendC::MrgSort<T, false>(sortTmpLocal, sortList, elementCountList, sortedNum, 0b1111, 1);
            // Separate the merged result into scores and indexes.
            AscendC::Extract(dstValueLocal, dstIndexLocal, sortTmpLocal, extractRepeatTimes);
    
            queTmp.FreeTensor(concatTmpLocal);
            queTmp.FreeTensor(sortTmpLocal);
            queIn.FreeTensor(valueLocal);
            queIn.FreeTensor(indexLocal);
            queCalc.FreeTensor(sortedLocal);
            queOut.EnQue(dstValueLocal);
            queOut.EnQue(dstIndexLocal);
        }
        __aicore__ inline void CopyOut()
        {
            AscendC::LocalTensor<T> dstValueLocal = queOut.DeQue<T>();
            AscendC::LocalTensor<uint32_t> dstIndexLocal = queOut.DeQue<uint32_t>();
            AscendC::DataCopy(dstValueGlobal, dstValueLocal, elementCount);
            AscendC::DataCopy(dstIndexGlobal, dstIndexLocal, elementCount);
            queOut.FreeTensor(dstValueLocal);
            queOut.FreeTensor(dstIndexLocal);
        }
    
    private:
        AscendC::TPipe pipe;
        AscendC::TQue<AscendC::TPosition::VECIN, 2> queIn;
        AscendC::TQue<AscendC::TPosition::VECOUT, 2> queOut;
        AscendC::TQue<AscendC::TPosition::VECIN, 2> queTmp;
        AscendC::TQue<AscendC::TPosition::VECIN, 1> queCalc;
        AscendC::GlobalTensor<T> valueGlobal;
        AscendC::GlobalTensor<uint32_t> indexGlobal;
        AscendC::GlobalTensor<T> dstValueGlobal;
        AscendC::GlobalTensor<uint32_t> dstIndexGlobal;
        uint32_t elementCount = 128;
        uint32_t concatRepeatTimes;
        uint32_t inBufferSize;
        uint32_t outBufferSize;
        uint32_t calcBufferSize;
        uint32_t tmpBufferSize;
        uint32_t sortedLocalSize;
        uint32_t sortTmpLocalSize;
        uint32_t sortRepeatTimes;
        uint32_t extractRepeatTimes;
    };
    
    extern "C" __global__ __aicore__ void sort_operator(__gm__ uint8_t *src0Gm, __gm__ uint8_t *src1Gm, __gm__ uint8_t *dst0Gm, __gm__ uint8_t *dst1Gm)
    {
        FullSort<half> op;
        op.Init(src0Gm, src1Gm, dst0Gm, dst1Gm);
        op.Process();
    }
    
    Result example:
    Input (srcValueGm): 128 pieces of half-type data
    [31 30 29 ... 2 1 0
     63 62 61 ... 34 33 32
     95 94 93 ... 66 65 64
     127 126 125 ... 98 97 96]
    Input (srcIndexGm):
    [31 30 29 ... 2 1 0
     63 62 61 ... 34 33 32
     95 94 93 ... 66 65 64
     127 126 125 ... 98 97 96]
    Output (dstValueGm):
    [127 126 125 ... 2 1 0]
    Output (dstIndexGm):
    [127 126 125 ... 2 1 0]
    
  • Processing 64 pieces of half-type data

    This example applies to:

    Atlas inference product's AI Core

    #include "kernel_operator.h"
    
    template <typename T>
    class FullSort
    {
    public:
        __aicore__ inline FullSort() {}
        __aicore__ inline void Init(__gm__ uint8_t *srcValueGm, __gm__ uint8_t *srcIndexGm, __gm__ uint8_t *dstValueGm, __gm__ uint8_t *dstIndexGm)
        {
            concatRepeatTimes = elementCount / 16;
            inBufferSize = elementCount * sizeof(uint32_t);
            outBufferSize = elementCount * sizeof(uint32_t);
            calcBufferSize = elementCount * 8;
            tmpBufferSize = elementCount * 8;
            sortedLocalSize = elementCount * 8 * sizeof(T);
            sortRepeatTimes = elementCount / 16;
            extractRepeatTimes = elementCount / 16;
            sortTmpLocalSize = elementCount * 8 * sizeof(T);
            valueGlobal.SetGlobalBuffer((__gm__ T *)srcValueGm);
            indexGlobal.SetGlobalBuffer((__gm__ uint32_t *)srcIndexGm);
            dstValueGlobal.SetGlobalBuffer((__gm__ T *)dstValueGm);
            dstIndexGlobal.SetGlobalBuffer((__gm__ uint32_t *)dstIndexGm);
            pipe.InitBuffer(queIn, 2, inBufferSize);
            pipe.InitBuffer(queOut, 2, outBufferSize);
            pipe.InitBuffer(queCalc, 1, calcBufferSize * sizeof(T));
            pipe.InitBuffer(queTmp, 2, tmpBufferSize * sizeof(T));
        }
        __aicore__ inline void Process()
        {
            CopyIn();
            Compute();
            CopyOut();
        }
    
    private:
        __aicore__ inline void CopyIn()
        {
            AscendC::LocalTensor<T> valueLocal = queIn.AllocTensor<T>();
            AscendC::DataCopy(valueLocal, valueGlobal, elementCount);
            queIn.EnQue(valueLocal);
    
            AscendC::LocalTensor<uint32_t> indexLocal = queIn.AllocTensor<uint32_t>();
            AscendC::DataCopy(indexLocal, indexGlobal, elementCount);
            queIn.EnQue(indexLocal);
        }
        __aicore__ inline void Compute()
        {
            AscendC::LocalTensor<T> valueLocal = queIn.DeQue<T>();
            AscendC::LocalTensor<uint32_t> indexLocal = queIn.DeQue<uint32_t>();
            AscendC::LocalTensor<T> sortedLocal = queCalc.AllocTensor<T>();
            AscendC::LocalTensor<T> concatTmpLocal = queTmp.AllocTensor<T>();
            AscendC::LocalTensor<T> sortTmpLocal = queTmp.AllocTensor<T>();
            AscendC::LocalTensor<T> dstValueLocal = queOut.AllocTensor<T>();
            AscendC::LocalTensor<uint32_t> dstIndexLocal = queOut.AllocTensor<uint32_t>();
            AscendC::LocalTensor<T> concatLocal;
    
            // Preprocess the scores into the input format required by Sort.
            AscendC::Concat(concatLocal, valueLocal, concatTmpLocal, concatRepeatTimes);
            // Sort the data block by block; sortRepeatTimes = elementCount / 16 yields four sorted blocks.
            AscendC::Sort<T, false>(sortedLocal, concatLocal, indexLocal, sortTmpLocal, sortRepeatTimes);
            // Split the Sort output into four sorted lists of equal length.
            uint32_t singleMergeTmpElementCount = elementCount / 4;
            uint32_t baseOffset = AscendC::GetSortOffset<T>(singleMergeTmpElementCount);
            AscendC::MrgSortSrcList sortList = AscendC::MrgSortSrcList(sortedLocal[0], sortedLocal[baseOffset], sortedLocal[2 * baseOffset], sortedLocal[3 * baseOffset]);
            uint16_t singleDataSize = elementCount / 4;
            const uint16_t elementCountList[4] = {singleDataSize, singleDataSize, singleDataSize, singleDataSize};
            uint32_t sortedNum[4];

            // Merge all four lists (validBit = 0b1111) in a single iteration.
            AscendC::MrgSort<T, false>(sortTmpLocal, sortList, elementCountList, sortedNum, 0b1111, 1);
            // Separate the merged result into scores and indexes.
            AscendC::Extract(dstValueLocal, dstIndexLocal, sortTmpLocal, extractRepeatTimes);
    
            queTmp.FreeTensor(concatTmpLocal);
            queTmp.FreeTensor(sortTmpLocal);
            queIn.FreeTensor(valueLocal);
            queIn.FreeTensor(indexLocal);
            queCalc.FreeTensor(sortedLocal);
            queOut.EnQue(dstValueLocal);
            queOut.EnQue(dstIndexLocal);
        }
        __aicore__ inline void CopyOut()
        {
            AscendC::LocalTensor<T> dstValueLocal = queOut.DeQue<T>();
            AscendC::LocalTensor<uint32_t> dstIndexLocal = queOut.DeQue<uint32_t>();
            AscendC::DataCopy(dstValueGlobal, dstValueLocal, elementCount);
            AscendC::DataCopy(dstIndexGlobal, dstIndexLocal, elementCount);
            queOut.FreeTensor(dstValueLocal);
            queOut.FreeTensor(dstIndexLocal);
        }
    
    private:
        AscendC::TPipe pipe;
        AscendC::TQue<AscendC::TPosition::VECIN, 2> queIn;
        AscendC::TQue<AscendC::TPosition::VECOUT, 2> queOut;
        AscendC::TQue<AscendC::TPosition::VECIN, 2> queTmp;
        AscendC::TQue<AscendC::TPosition::VECIN, 1> queCalc;
        AscendC::GlobalTensor<T> valueGlobal;
        AscendC::GlobalTensor<uint32_t> indexGlobal;
        AscendC::GlobalTensor<T> dstValueGlobal;
        AscendC::GlobalTensor<uint32_t> dstIndexGlobal;
        uint32_t elementCount = 64;
        uint32_t concatRepeatTimes;
        uint32_t inBufferSize;
        uint32_t outBufferSize;
        uint32_t calcBufferSize;
        uint32_t tmpBufferSize;
        uint32_t sortedLocalSize;
        uint32_t sortTmpLocalSize;
        uint32_t sortRepeatTimes;
        uint32_t extractRepeatTimes;
    };
    
    extern "C" __global__ __aicore__ void sort_operator(__gm__ uint8_t *src0Gm, __gm__ uint8_t *src1Gm, __gm__ uint8_t *dst0Gm, __gm__ uint8_t *dst1Gm)
    {
        FullSort<half> op;
        op.Init(src0Gm, src1Gm, dst0Gm, dst1Gm);
        op.Process();
    }
    
    Result example:
    Input (srcValueGm): 64 pieces of half-type data
    [15 14 13 ... 2 1 0
     31 30 29 ... 18 17 16
     47 46 45 ... 34 33 32
     63 62 61 ... 50 49 48]
    Input (srcIndexGm):
    [15 14 13 ... 2 1 0
     31 30 29 ... 18 17 16
     47 46 45 ... 34 33 32
     63 62 61 ... 50 49 48]
    Output (dstValueGm):
    [63 62 61 ... 2 1 0]
    Output (dstIndexGm):
    [63 62 61 ... 2 1 0]
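
The result examples above can be cross-checked on the host with the C++ standard library. The sketch below is plain C++ rather than Ascend C: MakeExampleInput and ReferenceMerge are hypothetical helper names, values are modeled as float regardless of the device data type, and std::inplace_merge stands in for the device-side merge.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <functional>
#include <vector>

// Builds the input shown in the result examples: four descending runs of
// `runLen` values each, [runLen-1 .. 0], [2*runLen-1 .. runLen], and so on.
inline std::vector<float> MakeExampleInput(uint32_t runLen) {
    std::vector<float> v;
    v.reserve(4 * runLen);
    for (uint32_t run = 0; run < 4; ++run) {
        for (uint32_t k = 0; k < runLen; ++k) {
            v.push_back(static_cast<float>((run + 1) * runLen - 1 - k));
        }
    }
    return v;
}

// Host-side reference: merges the four descending runs into one descending
// list using pairwise std::inplace_merge calls.
inline std::vector<float> ReferenceMerge(std::vector<float> v, uint32_t runLen) {
    auto b = v.begin();
    std::inplace_merge(b, b + runLen, b + 2 * runLen, std::greater<float>());
    std::inplace_merge(b + 2 * runLen, b + 3 * runLen, b + 4 * runLen, std::greater<float>());
    std::inplace_merge(b, b + 2 * runLen, v.end(), std::greater<float>());
    return v;
}
```

Calling ReferenceMerge(MakeExampleInput(16), 16) produces the 64 values from 63 down to 0, and a run length of 32 produces the 128 values from 127 down to 0, matching the dstValueGm rows of the two examples.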