TopK

Function Usage

Obtains the first k maximum or minimum values of the last dimension and their corresponding indexes.

If the input is a vector, the API finds the first k maximum or minimum values and their corresponding indexes within the vector. If the input is a matrix, the API finds the first k maximum or minimum values and their corresponding indexes in each row along the last dimension. This API accepts inputs of at most two dimensions; inputs with more dimensions are not supported.

As shown in the figure below, a two-dimensional matrix with a shape of (4, 32) is sorted with k set to 1, yielding the output [[32] [32] [32] [32]].

  • Essential concepts

    The preceding example introduces several essential concepts. The number of rows is the outer axis length (outter), and the actual number of elements in each row is the actual inner axis length (n). This API requires the input inner axis length to be an integer multiple of 32. If n is not a multiple of 32, pad it up to an integer multiple of 32. The length after padding is the inner axis length (inner). For example, if the actual length n of each row is 31, which is not an integer multiple of 32, inner becomes 32 after padding. In the figure, padding indicates the padding operation. The relationship between n and inner is as follows: when n is an integer multiple of 32, inner = n; otherwise, inner > n.

  • API modes

    This API supports two modes: normal mode and small mode. The normal mode is a general-purpose mode. The small mode is a high-performance mode designed for scenarios where the inner axis length is fixed at 32 (unit: number of elements). Because the inner axis length is fixed at 32, the small mode allows more targeted processing with fewer constraints and higher performance. You are advised to use the small mode when the inner axis length is 32.

  • Additional function: This API lets you mark the sorting of certain rows as invalid through the finishedLocal input. If the value of finishedLocal for a row is true, the sorting of that row is invalid, and all k index values output for that row in dstIndexLocal are set to the invalid index n.
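
The relationship between n and inner described above can be sketched as a small host-side helper. This is only an illustration: PadInnerLength and kInnerAlign are hypothetical names that are not part of the TopK API.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper (not part of the TopK API): rounds the actual inner
// axis length n up to an integer multiple of 32 to obtain inner.
constexpr int32_t kInnerAlign = 32;

inline int32_t PadInnerLength(int32_t n)
{
    assert(n >= 1);  // n is the actual number of elements in each row
    return (n + kInnerAlign - 1) / kInnerAlign * kInnerAlign;
}
```

For n = 31 the helper yields 32, and for n = 32 it yields 32 unchanged, which matches the rule that inner = n when n is already a multiple of 32.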

Principles

The figure below illustrates the internal algorithm block diagram of TopK high-level APIs, taking a float input tensor in ND format with a shape of [outter, inner] as an example.

There are two branches based on the selection of TopK modes.

  • The TopK normal mode is computed as follows:
    1. If the template parameter isInitIndex is set to false, generate an index from 0 to inner – 1.
      • Method 1: Use CreateVecIndex to generate an index from 0 to inner – 1.
      • Method 2: Use ArithProgression to generate an index from 0 to inner – 1.
    2. If isLargest is set to false, multiply data by –1 because the Sort32 instruction sorts data in descending order by default.
    3. Sort the input data.
      • Method 1:

        Use the high-level API Sort to sort data.

      • Method 2:
        1. Use Sort32 to sort data to ensure that every 32 pieces of data are sorted in order.
        2. Use the MrgSort instruction to merge and sort all sorted data blocks.
    4. Use the GatherMask instruction to extract the first k pieces of data and their indexes.
    5. If isHasfinish is set to true, update all indexes of the rows whose finishedLocal value is true to n.
    6. If isLargest is set to false, multiply data by –1 to restore the data.
  • The TopK small mode is computed as follows:
    1. If the template parameter isInitIndex is set to false, generate an index from 0 to inner – 1, and use the Copy instruction to replicate it into outter copies.
      • Method 1: Use CreateVecIndex to generate an index from 0 to inner – 1.
      • Method 2: Use ArithProgression to generate an index from 0 to inner – 1.
    2. If isLargest is set to false, multiply data by –1 because the Sort32 instruction sorts input data in descending order by default.
    3. Use Sort32 to sort data.
    4. Use the GatherMask instruction to extract the first k pieces of data and their indexes.
    5. If isLargest is set to false, multiply the input data by –1 to restore the data.
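
The computation steps above can be mirrored by a host-side reference for a single row, which may help when checking kernel results. This is a minimal sketch under stated assumptions, not the on-device implementation: TopKRowResult and TopKRowReference are hypothetical names, std::stable_sort stands in for the Sort32 and MrgSort instructions, only the first n elements are sorted so that padding never takes part, and NaN inputs are not modeled.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

// Hypothetical result type for one row.
struct TopKRowResult {
    std::vector<float> values;
    std::vector<int32_t> indexes;
};

// Host-side reference for one row. src holds at least n elements.
inline TopKRowResult TopKRowReference(std::vector<float> src, int32_t k, int32_t n,
                                      bool isLargest, bool finished)
{
    assert(k >= 1 && k <= n && static_cast<std::size_t>(n) <= src.size());
    // Generate the indexes 0 .. n - 1 (padding is skipped in this sketch).
    std::vector<int32_t> idx(static_cast<std::size_t>(n));
    std::iota(idx.begin(), idx.end(), 0);
    // Negate the data so that a descending sort yields the minimum values first.
    if (!isLargest) {
        for (float &v : src) {
            v = -v;
        }
    }
    // Descending sort; stable, so equal values keep the lower index first.
    std::stable_sort(idx.begin(), idx.end(),
                     [&src](int32_t a, int32_t b) { return src[a] > src[b]; });
    TopKRowResult out;
    for (int32_t i = 0; i < k; ++i) {
        // Undo the negation when the minimum values were requested.
        out.values.push_back(isLargest ? src[idx[i]] : -src[idx[i]]);
        // A finished (invalid) row reports the invalid index n for every entry.
        out.indexes.push_back(finished ? n : idx[i]);
    }
    return out;
}
```

Calling TopKRowReference({3, 1, 2, 1}, 2, 4, false, false) selects the two smallest values and reports the indexes 1 and 3, with the lower index first for the tie.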

Prototype

  • Allocate the temporary space inside the API.
    template <typename T, bool isInitIndex = false, bool isHasfinish = false, bool isReuseSrc = false, enum TopKMode topkMode = TopKMode::TOPK_NORMAL>
    __aicore__ inline void TopK(const LocalTensor<T> &dstValueLocal, const LocalTensor<int32_t> &dstIndexLocal, const LocalTensor<T> &srcLocal, const LocalTensor<int32_t> &srcIndexLocal, const LocalTensor<bool> &finishLocal, const int32_t k, const TopkTiling &tilling, const TopKInfo &topKInfo, const bool isLargest = true)
    
  • Pass the temporary space through the tmpLocal input parameter.
    template <typename T, bool isInitIndex = false, bool isHasfinish = false, bool isReuseSrc = false, enum TopKMode topkMode = TopKMode::TOPK_NORMAL>
    __aicore__ inline void TopK(const LocalTensor<T> &dstValueLocal, const LocalTensor<int32_t> &dstIndexLocal, const LocalTensor<T> &srcLocal, const LocalTensor<int32_t> &srcIndexLocal, const LocalTensor<bool> &finishLocal, const LocalTensor<uint8_t> &tmpLocal, const int32_t k, const TopkTiling &tilling, const TopKInfo &topKInfo, const bool isLargest = true)
    

Due to the complex logical computation involved in the internal implementation of this API, additional temporary space is required to store intermediate variables generated during computation. The temporary space can be allocated inside the API or passed by developers through the tmpLocal input parameter.

  • When the temporary space is allocated inside the API, you do not need to allocate the space, but must reserve the required size for the space.
  • When the temporary space is passed through the tmpLocal input parameter, that tensor serves as the temporary space and the API does not allocate any. You manage the tmpLocal space yourself and can reuse the buffer after the API call, which avoids repeated allocation and deallocation and improves flexibility and buffer utilization. To obtain the required size (BufferSize) of tmpLocal, use the GetTopKMaxMinTmpSize API provided in TopK Tiling, which gives the maximum and minimum temporary space sizes.

Parameters

Table 1 Parameters in the template

Parameter

Description

T

Data type of the operands srcLocal and dstValueLocal.

isInitIndex

Whether to pass the index of input data.

  • If this parameter is set to true, the index of input data needs to be passed through the srcIndexLocal parameter. For details about the rules, see the description of the srcIndexLocal parameter in Table 2.
  • The value false indicates that the index is not passed and is directly generated by the TopK API.

isHasfinish

The TopK API allows you to specify that the sorting of certain rows is invalid using the finishedLocal parameter. This template parameter determines whether to enable the preceding function. The value true indicates that the function is enabled, and the value false indicates that the function is disabled.

The value can be true or false in normal mode.

The value can only be false in small mode.

For details about how to use isHasfinish and finishedLocal together, see the description of finishedLocal in Table 2.

isReuseSrc

Whether the source operand can be modified. This parameter is reserved. Pass the default value false.

topkMode

TopK mode selection. The data structure is as follows:

enum class TopKMode {
    TOPK_NORMAL, // Normal mode
    TOPK_NSMALL, // Small mode
};
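
Because topkMode is a template parameter, the mode has to be known at compile time. The following sketch shows one way to encode the recommendation above (small mode when the inner axis length is 32) together with the restriction that isHasfinish must be false in small mode. ChooseTopKMode is a hypothetical helper, and the enum is redeclared locally only so that the sketch compiles on its own.

```cpp
#include <cstdint>

// Local mirror of the enum shown above, so the sketch is self-contained.
enum class TopKMode {
    TOPK_NORMAL, // Normal mode
    TOPK_NSMALL, // Small mode
};

// Hypothetical helper: selects small mode only when inner is exactly 32 and
// the finished-row function (isHasfinish) is not needed.
constexpr TopKMode ChooseTopKMode(int32_t inner, bool needFinished)
{
    return (inner == 32 && !needFinished) ? TopKMode::TOPK_NSMALL
                                          : TopKMode::TOPK_NORMAL;
}
```

Since the helper is constexpr, its result can be used directly as the topkMode template argument.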
Table 2 API parameters

Parameter

Input/Output

Description

dstValueLocal

Output

Destination operand, which is used to store the sorted k values.

The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.

The data type of the destination operand must be the same as that of the source operand srcLocal.

Normal mode:

  • The output shape is outter * k_pad, which means that outter pieces of data are output and each piece has a length of k_pad. k_pad is the value obtained by rounding k up so that each piece is 32-byte aligned for the input data type.
  • You need to allocate a space with the size of k_pad * outter * sizeof(T) for dstValueLocal.
  • The first k values of each data piece are the first k maximum or minimum values of that piece. Elements k + 1 to k_pad of each data piece are padding and hold undefined values.
  • k_pad is calculated as follows:
    if (sizeof(T) == sizeof(float)) {
        // When the type of inputs srcLocal and dstValueLocal is float, which has 4 bytes, round up k to the nearest multiple of 8 to obtain k_pad, thus achieving 32-byte alignment.
        k_pad = (k + 7) / 8 * 8;
    } else {
        // When the type of inputs srcLocal and dstValueLocal is half, which has 2 bytes, round up k to the nearest multiple of 16 to obtain k_pad, thus achieving 32-byte alignment.
        k_pad = (k + 15) / 16 * 16;
    }
    

Small mode:

  • The output shape is outter * k, which means that there are outter pieces of data output and each piece has a length of k.
  • The output values need to be saved in a space with a size of k * outter * sizeof(T). You need to allocate the actual buffer space to dstValueLocal based on this size and the framework's alignment requirements.
    NOTE:

    The size of the allocated buffer must be 32-byte aligned according to the framework's requirements. If k * outter * sizeof(T) is not 32-byte aligned, round it up to the nearest multiple of 32 bytes. The extra buffer space allocated for alignment is not filled with valid data and holds undefined values.

dstIndexLocal

Output

Destination operand, which is used to store indexes corresponding to the sorted k values.

The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.

Normal mode:

  • The output shape is outter * kpad_index, which means that outter pieces of data are output and each piece has a length of kpad_index. kpad_index is the value obtained by rounding k up so that each piece is 32-byte aligned for the index type.
  • You need to allocate a space with the size of kpad_index * outter * sizeof(int32_t) for dstIndexLocal.
  • The first k values of each data piece are the indexes corresponding to the first k maximum or minimum values of that piece. Elements k + 1 to kpad_index of each data piece are padding and hold undefined values.
  • kpad_index is calculated as follows:
    // Since the type of dstIndexLocal is int32_t, which has 4 bytes, round up k to the nearest multiple of 8 to obtain kpad_index, thus achieving 32-byte alignment.
    kpad_index = (k + 7) / 8 * 8;
    

Small mode:

  • The output shape is outter * k, which means that there are outter pieces of data output and each piece has a length of k.
  • The output indexes need to be saved in a space with a size of k * outter * sizeof(int32_t). You need to allocate the actual buffer space to dstIndexLocal based on this size and the framework's alignment requirements.
    NOTE:

    The size of the allocated buffer must be 32-byte aligned according to the framework's requirements. If k * outter * sizeof(int32_t) is not 32-byte aligned, round it up to the nearest multiple of 32 bytes. The extra buffer space allocated for alignment is not filled with valid data and holds undefined values.
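
The sizing rules for dstValueLocal and dstIndexLocal in both modes can be summarized in a short host-side sketch. The helper names below are hypothetical and not part of the TopK API. elemSize is sizeof(T) for dstValueLocal (2 for half, 4 for float) and sizeof(int32_t) for dstIndexLocal.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical helper: rounds a byte count up to a multiple of 32 bytes.
inline std::size_t AlignUp32Bytes(std::size_t bytes)
{
    return (bytes + 31) / 32 * 32;
}

// Normal mode: every row is padded to 32-byte alignment, so the buffer holds
// outter rows of k_pad (or kpad_index) elements.
inline std::size_t NormalModeDstBytes(int32_t outter, int32_t k, std::size_t elemSize)
{
    assert(outter >= 1 && k >= 1 && (elemSize == 2 || elemSize == 4));
    const std::size_t elemsPerBlock = 32 / elemSize;  // 8 for 4-byte, 16 for 2-byte types
    const std::size_t kPad =
        (static_cast<std::size_t>(k) + elemsPerBlock - 1) / elemsPerBlock * elemsPerBlock;
    return kPad * static_cast<std::size_t>(outter) * elemSize;
}

// Small mode: rows are stored back to back with length k, and only the total
// size is rounded up to a multiple of 32 bytes.
inline std::size_t SmallModeDstBytes(int32_t outter, int32_t k, std::size_t elemSize)
{
    assert(outter >= 1 && k >= 1 && (elemSize == 2 || elemSize == 4));
    return AlignUp32Bytes(static_cast<std::size_t>(k) * static_cast<std::size_t>(outter) * elemSize);
}
```

For example, with outter = 2, k = 5, and float data, normal mode needs 8 * 2 * 4 = 64 bytes for dstValueLocal, because k is rounded up to 8 per row.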

srcLocal

Input

Source operand, which is used to store the values to be sorted.

The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.

srcIndexLocal

Input

Source operand, which is used to store indexes corresponding to the values to be sorted.

The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.

This parameter works in conjunction with the template parameter isInitIndex. If isInitIndex is set to false, srcIndexLocal only needs to be defined and passed to the API; no value needs to be assigned. If isInitIndex is set to true, the index values must be passed through srcIndexLocal. The rules for setting srcIndexLocal are as follows:

Normal mode:

  • The shape of the input index data is 1 * inner. Here, the outter pieces of data use the same index. You need to allocate a space with the size of inner * sizeof(int32_t) for it.
  • When n is less than inner, you need to pad the index data from n to the length of inner.
  • Padding rule: The padded indexes cannot affect the overall sorting. You are advised to use the following padding method: The padded values are incremented based on the original indexes. For example, if the original indexes are 0, 1, 2, ..., n – 1, the padded indexes are 0, 1, 2, ..., n, n + 1, ..., inner – 1.

Small mode:

  • The shape of input index data is outter * inner. You need to allocate a space with the size of outter * inner * sizeof(int32_t) for it.
  • When n is less than 32, you need to pad the outter pieces of data, with each data piece padded from n to the length of 32.
  • Padding rule: The padded data cannot affect the overall sorting. You are advised to use the following padding method: The padded values are incremented based on the original indexes. For example, if the original indexes are 0, 1, 2, ..., n – 1, the padded indexes are 0, 1, 2, ..., n, n + 1, ..., inner – 1.
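
The index layouts for the two modes can be sketched on the host as follows. This only illustrates the recommended incrementing padding. The helper names are hypothetical, and in a real kernel the values would be written to the srcIndexLocal tensor rather than to a std::vector.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Normal mode: a single row of indexes 0 .. inner - 1 shared by all outter rows.
inline std::vector<int32_t> MakeNormalModeIndexes(int32_t inner)
{
    assert(inner > 0 && inner % 32 == 0);
    std::vector<int32_t> idx(static_cast<std::size_t>(inner));
    for (int32_t i = 0; i < inner; ++i) {
        idx[static_cast<std::size_t>(i)] = i;  // positions n .. inner - 1 simply keep counting
    }
    return idx;
}

// Small mode: outter rows, each holding the indexes 0 .. 31 (inner is fixed at 32).
inline std::vector<int32_t> MakeSmallModeIndexes(int32_t outter)
{
    constexpr int32_t inner = 32;
    assert(outter >= 1);
    std::vector<int32_t> idx(static_cast<std::size_t>(outter) * inner);
    for (int32_t row = 0; row < outter; ++row) {
        for (int32_t col = 0; col < inner; ++col) {
            idx[static_cast<std::size_t>(row) * inner + col] = col;
        }
    }
    return idx;
}
```

The normal mode buffer occupies inner * sizeof(int32_t) bytes and the small mode buffer occupies outter * inner * sizeof(int32_t) bytes, as required above.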

finishedLocal

Input

Source operand, which is used to specify that the sorting of certain rows is invalid. Its shape is (outter, 1).

The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.

This parameter works in conjunction with the template parameter isHasfinish. In normal mode, isHasfinish can be set to true or false. In small mode, isHasfinish can only be set to false.

  • When isHasfinish is set to true:
    • If the value of finishedLocal for a row is true, the sorting of that row is invalid, and all k index values of that row in dstIndexLocal are set to n.
    • If the value of finishedLocal for a row is false, the sorting of that row is valid.
  • When isHasfinish is set to false, finishedLocal only needs to be defined and passed to the API; no value needs to be assigned. A definition example is as follows:
    LocalTensor<bool> finishedLocal;
    

tmpLocal

Input

Temporary space. This parameter is used to store intermediate variables during complex internal API computation and is provided by developers. The data type is fixed at uint8_t.

The type is LocalTensor, and the logical location can only be VECCALC.

For details about how to obtain the temporary space size (BufferSize), see TopK Tiling.

k

Input

Number of maximum or minimum values (and their corresponding indexes) to obtain. The data type is int32_t.

The value of k must meet the following condition: 1 ≤ k ≤ n.

tiling

Input

Tiling information required for TopK computation. For details about how to obtain the tiling information, see TopK Tiling.

topKInfo

Input

Shape information of srcLocal, of the TopKInfo type. The definition is as follows:

struct TopKInfo {
    int32_t outter = 1;    // outer axis length of the input data to be sorted.
    int32_t inner;         // inner axis length of the input data to be sorted. The value of inner must be an integer multiple of 32.
    int32_t n;             // actual inner axis length of the input data to be sorted.
};
  • The value of topKInfo.inner must be an integer multiple of 32.
  • topKInfo.inner is the value obtained by padding topKInfo.n up to the nearest 32-aligned integer. Therefore, the value of topKInfo.n must meet the following condition: 1 ≤ topKInfo.n ≤ topKInfo.inner.
  • In small mode, topKInfo.inner must be set to 32.
  • In normal mode, the maximum value of topKInfo.inner is 4096.
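
The constraints on topKInfo and k listed above can be gathered into a host-side pre-check. This is a sketch only: TopKInfoSketch is a local stand-in for TopKInfo, IsValidTopKConfig and CheckTopKConfig are hypothetical helpers, and the outter >= 1 check is an assumption that the document does not state explicitly.

```cpp
#include <cassert>
#include <cstdint>

// Local stand-in for TopKInfo so that the sketch compiles on its own.
struct TopKInfoSketch {
    int32_t outter = 1;  // outer axis length
    int32_t inner = 0;   // padded inner axis length
    int32_t n = 0;       // actual inner axis length
};

// Returns true if the documented constraints on topKInfo and k hold.
inline bool IsValidTopKConfig(const TopKInfoSketch &info, int32_t k, bool smallMode)
{
    if (info.outter < 1) return false;                      // assumption: at least one row
    if (info.n < 1) return false;                           // 1 <= n
    if (info.inner != (info.n + 31) / 32 * 32) return false; // inner is n padded up to 32-aligned
    if (k < 1 || k > info.n) return false;                  // 1 <= k <= n
    if (smallMode) return info.inner == 32;                 // small mode: inner must be 32
    return info.inner <= 4096;                              // normal mode: inner is at most 4096
}

// Debug-build guard that a caller might place before invoking TopK.
inline void CheckTopKConfig(const TopKInfoSketch &info, int32_t k, bool smallMode)
{
    assert(IsValidTopKConfig(info, k, smallMode));
}
```

Running such a check on the host before launching the kernel catches an unpadded inner or an out-of-range k early.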

isLargest

Input

The value is of the bool type and defaults to true. If this parameter is set to true, data is sorted in descending order and the first k maximum values are obtained. If this parameter is set to false, data is sorted in ascending order and the first k minimum values are obtained.

Returns

None

Availability

Precautions

  • For details about the alignment requirements of the operand address offset, see General Restrictions.
  • The source operand address must not overlap the destination operand address.
  • If srcLocal[i] and srcLocal[j] are the same and i is greater than j, srcLocal[j] is selected first.
  • inf is considered to be the maximum value in TopK.
  • nan is placed at the front in TopK sorting regardless of whether the data is sorted in descending or ascending order.

Example

This sample implements the code logic of normal and small modes.

Table 3 Normal mode sample analysis

Sample Description

This sample demonstrates the sorting of a float matrix with a shape of (2, 32). It calculates the first five minimum values of each row in the matrix.

This sample uses the API in normal mode, passes the input data indexes, and passes finishedLocal to mark the sorting of certain rows as invalid.

Inputs

  • Template parameter (T): float
  • Template parameter (isInitIndex): true
  • Template parameter (isHasfinish): true
  • Template parameter (topkMode): TopKMode::TOPK_NORMAL
  • Input data (finishLocal):
    [False  True  False  False False False  False False False  False False False 
     False  False False False False False False False False False False False 
     False False False False False False False False]
    

    Note: The data movement amount of DataCopy must be a multiple of 32 bytes. Therefore, the actual valid inputs of finishLocal are False and True. The remaining values are padded values used for 32-byte alignment and do not actually participate in calculation.

  • Input data (k): 5
  • Input data (topKInfo):
    struct TopKInfo {
        int32_t outter = 2;
        int32_t inner = 32;
        int32_t n = 32;
    };
    
  • Input data (isLargest): false
  • Input data (srcLocal):
    [[-18096.555   -11389.83    -43112.895   -21344.77     57755.918
       50911.145    24912.621   -12683.089    45088.004   -39351.043
      -30153.293    11478.329    12069.15     -9215.71     45716.44
      -21472.398   -37372.16    -17460.414    22498.03     21194.838
      -51229.17    -51721.918   -47510.38     47899.11     43008.176
        5495.8975  -24176.97    -14308.27     53950.695     7652.6035
      -45169.168   -26275.518  ]
     [ -9196.681   -31549.518    18589.23    -12427.927    50491.81
      -20078.11    -25606.107   -34466.773   -42512.805    50584.48
       35919.934   -17283.5       6488.137   -12885.134     1942.2147
      -50611.96     52671.477    23179.662    25814.875      -69.73492
       33906.797   -34662.61     46168.71    -52391.258    57435.332
       50269.414    40935.05     21164.176     4028.458   -29022.918
      -46391.133     1971.2042 ]]
    
  • Input data (srcIndexLocal):
    [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
     24 25 26 27 28 29 30 31]
    

Output Data

  • The output data (dstValueLocal) is as follows. The length of each row is k_pad. The first five values of each data piece are the first five minimum values of that piece. The last three values are padding and hold undefined values.
    [[-51721.918 -51229.17  -47510.38  -45169.168 -43112.895     0.     0.   0.   ]
     [-52391.258 -50611.96  -46391.133 -42512.805 -34662.61      0.     0.   0.   ]]
    
  • The output data (dstIndexLocal) is as follows:
    • The length of each row is kpad_index. The first five values of each data piece are the indexes corresponding to the first five minimum values of that piece. The last three values are padding and hold undefined values.
    • Since finishLocal of the second row of data is true, the sorting of the second row of data is invalid. Consequently, the output index values are all equal to the actual length of the inner axis, which is 32.
      [[21 20 22 30  2  0  0  0]
       [32 32 32 32 32  0  0  0]]
      
Table 4 Small mode sample analysis

Sample Description

This sample demonstrates the sorting of float input data with a shape of (4, 17). It calculates the first eight maximum values of each data row.

This sample uses the API in small mode and passes the input data indexes.

Inputs

  • Template parameter (T): float
  • Template parameter (isInitIndex): true
  • Template parameter (isHasfinish): false
  • Template parameter (topkMode): TopKMode::TOPK_NSMALL
  • Input data (finishLocal): LocalTensor<bool> finishedLocal. No value needs to be assigned.
  • Input data (k): 8
  • Input data (topKInfo):
    struct TopKInfo {
        int32_t outter = 4;
        int32_t inner = 32;
        int32_t n = 17;
    };
    
  • Input data (isLargest): true
  • Input data (srcLocal): Here n is 17 and is not an integer multiple of 32. Pad it up to 32 using -inf.
    [[ 55492.18     27748.229   -51100.11     19276.926    14828.149
      -20771.824    57553.4     -21504.092   -57423.414      142.36443
       -5223.254    54669.473    54519.184    10165.924     -658.4564
        2264.2397  -52942.883           -inf         -inf         -inf
              -inf         -inf         -inf         -inf         -inf
              -inf         -inf         -inf         -inf         -inf
              -inf         -inf]
     [-52849.074    57778.72     37069.496    16273.109   -25150.637
      -35680.5     -15823.097     4327.308   -35853.86     -7052.2627
       44148.117   -17515.457   -18926.059    -1650.6737   21753.582
       -2589.2822   39390.4             -inf         -inf         -inf
              -inf         -inf         -inf         -inf         -inf
              -inf         -inf         -inf         -inf         -inf
              -inf         -inf]
     [-17539.186   -15220.923    29945.332    -4088.1514   28482.525
       29750.484   -46082.03     31141.16     23140.047     8461.174
       39955.844    29401.35     53757.543    33584.566    -3543.6284
      -38318.344    22212.41            -inf         -inf         -inf
              -inf         -inf         -inf         -inf         -inf
              -inf         -inf         -inf         -inf         -inf
              -inf         -inf]
     [ -9970.768    -9191.963   -17903.045     2211.4912   47037.562
      -41114.824    13305.985    59926.07    -24316.797    -6462.8896
        5699.733    -5873.5015   15695.861   -38492.004    19581.654
      -36877.68     27090.158           -inf         -inf         -inf
              -inf         -inf         -inf         -inf         -inf
              -inf         -inf         -inf         -inf         -inf
              -inf         -inf]]
    
  • Input data (srcIndexLocal):
    [[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
      24 25 26 27 28 29 30 31]
     [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
      24 25 26 27 28 29 30 31]
     [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
      24 25 26 27 28 29 30 31]
     [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
      24 25 26 27 28 29 30 31]]
    

Output Data

  • Output data (dstValueLocal): Output the first eight maximum values of each data row.
    [[57553.4    55492.18   54669.473  54519.184  27748.229  19276.926
      14828.149  10165.924 ]
     [57778.72   44148.117  39390.4    37069.496  21753.582  16273.109
       4327.308  -1650.6737]
     [53757.543  39955.844  33584.566  31141.16   29945.332  29750.484
      29401.35   28482.525 ]
     [59926.07   47037.562  27090.158  19581.654  15695.861  13305.985
       5699.733   2211.4912]]
    
  • Output data (dstIndexLocal): Output the indexes of the first eight maximum values of each data row.
    [[ 6  0 11 12  1  3  4 13]
     [ 1 10 16  2 14  3  7 13]
     [12 10 13  7  2  5 11  4]
     [ 7  4 16 14 12  6 10  3]]
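
As a cross-check of the small mode sample, the first row can be reproduced on the host. The sketch below is not the kernel code. TopKIndexesDescending and SampleRow0 are hypothetical helpers, and std::stable_sort stands in for the on-device sort. The row data is copied from the first row of srcLocal above, with its 17 actual values padded to 32 with -inf.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <numeric>
#include <vector>

// Returns the indexes of the k largest values of row, largest first.
// Equal values keep the lower index first because the sort is stable.
inline std::vector<int32_t> TopKIndexesDescending(const std::vector<float> &row, int32_t k)
{
    assert(k >= 1 && static_cast<std::size_t>(k) <= row.size());
    std::vector<int32_t> idx(row.size());
    std::iota(idx.begin(), idx.end(), 0);
    std::stable_sort(idx.begin(), idx.end(),
                     [&row](int32_t a, int32_t b) { return row[a] > row[b]; });
    idx.resize(static_cast<std::size_t>(k));
    return idx;
}

// First row of the small mode sample: n = 17 values padded to inner = 32 with -inf.
inline std::vector<float> SampleRow0()
{
    std::vector<float> row = {
        55492.18f,  27748.229f, -51100.11f, 19276.926f, 14828.149f, -20771.824f,
        57553.4f,   -21504.092f, -57423.414f, 142.36443f, -5223.254f, 54669.473f,
        54519.184f, 10165.924f, -658.4564f, 2264.2397f, -52942.883f};
    row.resize(32, -std::numeric_limits<float>::infinity());
    return row;
}
```

TopKIndexesDescending(SampleRow0(), 8) yields 6, 0, 11, 12, 1, 3, 4, 13, which agrees with the first row of dstIndexLocal in the sample output.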