TopK
Applicability
| Product | Supported |
|---|---|
| | √ |
| | √ |
| | x |
| | √ |
| | x |
| | x |
Function Usage
Obtains the first k maximum or minimum values of the last dimension and their corresponding indexes.
If the input is a vector, the API finds the first k maximum or minimum values and their corresponding indexes within the vector. If the input is a matrix, it finds the first k maximum or minimum values and their corresponding indexes in each row along the last dimension. This API supports input data of at most two dimensions.
As illustrated in the figure below, the sorting is conducted on a two-dimensional matrix with a shape of (4, 32) and with k set to 1, and yields the output result of [[32] [32] [32] [32]].

- Essential concepts
Based on the preceding example, some essential concepts need to be put forth. The number of rows is referred to as the outer axis length (outter), and the actual number of elements in each row is dubbed the actual inner axis length (n). This API requires that the input inner axis length be an integer multiple of 32. If n is not a multiple of 32, you need to pad it to an integer multiple of 32. The length after padding is called the inner axis length (inner). For example, in the following example, the actual length n of each row is 31, which is not an integer multiple of 32. After the length is padded, the inner becomes 32. padding in the figure indicates the padding operation. The relationship between n and inner is as follows: When n is an integer multiple of 32, inner = n; otherwise, inner > n.
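The relationship between n and inner can be written as a one-line round-up. The following is a minimal host-side sketch; PaddedInner and kInnerAlign are hypothetical names used only for illustration and are not part of the TopK API.

```cpp
#include <cstdint>

// Alignment unit of the inner axis, in number of elements (from the text above).
constexpr int32_t kInnerAlign = 32;

// Hypothetical helper: returns the smallest multiple of 32 that is >= n.
// If n is already a multiple of 32, the result equals n.
constexpr int32_t PaddedInner(int32_t n) {
    return (n + kInnerAlign - 1) / kInnerAlign * kInnerAlign;
}
```

For the row length n = 31 used above, PaddedInner(31) evaluates to 32.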

- API modes
This API supports two modes: normal mode and small mode. The normal mode is a general-purpose mode. The small mode is a high-performance mode designed for scenarios where the inner axis length is fixed at 32 (unit: number of elements). Because the inner axis length is fixed, the small mode allows more targeted processing with fewer constraints and higher performance. Use the small mode when the inner axis length is 32.
- Additional function: This API allows you to specify that the sorting of certain rows is invalid. This is controlled by the passed value of finishLocal. If the value of finishLocal in a row is true, the sorting for that row is invalid. At this point, all k index values in dstIndexLocal output after sorting will be set to the invalid index n.
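The row-wise behavior described above, including the finishLocal rule, can be modeled on the host as follows. This is a plain C++ sketch for checking expected results, not Ascend C kernel code. ReferenceTopK and TopKResult are hypothetical names, NaN inputs are not handled, and leaving the values of an invalid row as the sorted values is an assumption (the text only states that the indexes become n).

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

// Hypothetical result container: both vectors hold outter * k entries, row by row.
struct TopKResult {
    std::vector<float> values;
    std::vector<int32_t> indexes;
};

// Host-side reference model (assumption: no NaN in src).
// src is row-major with shape [outter, inner]; only the first n elements of each row are valid.
// finish may be empty (feature disabled) or hold one flag per row.
TopKResult ReferenceTopK(const std::vector<float>& src, int32_t outter, int32_t inner,
                         int32_t n, int32_t k, bool isLargest,
                         const std::vector<bool>& finish) {
    TopKResult out;
    for (int32_t row = 0; row < outter; ++row) {
        const float* rowData = src.data() + static_cast<std::size_t>(row) * inner;
        std::vector<int32_t> order(n);
        std::iota(order.begin(), order.end(), 0);
        // stable_sort keeps the lower index first when two values are equal.
        std::stable_sort(order.begin(), order.end(), [&](int32_t a, int32_t b) {
            return isLargest ? rowData[a] > rowData[b] : rowData[a] < rowData[b];
        });
        const bool finished = !finish.empty() && finish[row];
        for (int32_t i = 0; i < k; ++i) {
            out.values.push_back(rowData[order[i]]);
            // A finished row reports the invalid index n for all k entries.
            out.indexes.push_back(finished ? n : order[i]);
        }
    }
    return out;
}
```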

Principles
- MERGE_SORT algorithm
The figure below illustrates the internal algorithm block diagram of TopK high-level APIs, taking a float input tensor in ND format with a shape of [outter, inner] as an example.
Figure 1 TopK algorithm block diagram
There are two branches based on the selection of TopK modes.
- The TopK normal mode is computed as follows:
  - If the template parameter isInitIndex is set to false, generate indexes from 0 to inner - 1.
    For the Atlas A3 training products/Atlas A3 inference products, method 2 is used. For the Atlas A2 training products/Atlas A2 inference products, method 2 is used. For the Atlas inference products, method 2 is used.
    - Method 1: Use CreateVecIndex to generate indexes from 0 to inner - 1.
    - Method 2: Use Arange to generate indexes from 0 to inner - 1.
  - If isLargest is set to false, multiply the data by -1, because the Sort32 instruction sorts data in descending order by default.
  - Sort the input data.
    For the Atlas A3 training products/Atlas A3 inference products, method 2 is used. For the Atlas A2 training products/Atlas A2 inference products, method 2 is used. For the Atlas inference products, method 2 is used.
    - Method 1: Use the high-level API Sort to sort data.
    - Method 2:
      - Use Sort32 to sort data so that every 32 pieces of data are sorted in order.
      - Use the MrgSort instruction to merge and sort all sorted data blocks.
  - Use the GatherMask instruction to extract the first k pieces of data and their indexes.
  - If finishLocal[i] is true, the sorting result of the corresponding row is updated to the invalid index n.
  - If isLargest is set to false, multiply the data by -1 to restore the data.

  Note: On the Atlas inference products, use the ProposalConcat basic API to combine data and indexes, and use the RpSort16 basic API to sort the combined data. Then, use MrgSort4 to merge the sorted data. Finally, use the ProposalExtract basic API to extract the sorted data and indexes.
- The TopK small mode is computed as follows:
  - If the template parameter isInitIndex is set to false, generate indexes from 0 to inner - 1, and use the Copy instruction to make outter copies of them.
    For the Atlas A3 training products/Atlas A3 inference products, method 2 is used. For the Atlas A2 training products/Atlas A2 inference products, method 2 is used. For the Atlas inference products, method 2 is used.
    - Method 1: Use CreateVecIndex to generate indexes from 0 to inner - 1.
    - Method 2: Use Arange to generate indexes from 0 to inner - 1.
  - If isLargest is set to false, multiply the data by -1, because the Sort32 instruction sorts input data in descending order by default.
  - Use Sort32 to sort data.
  - Use the GatherMask instruction to extract the first k pieces of data and their indexes.
  - If isLargest is set to false, multiply the data by -1 to restore the data.

  Note: On the Atlas inference products, use the ProposalConcat basic API to combine data and indexes, and use the RpSort16 basic API to sort the combined data. Because inner is 32 in small mode and every 16 pieces of data are ordered after RpSort16 sorting, the MrgSort4 basic API is used between steps 3 and 4 to perform merging and sorting once.
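The method 2 flow of the normal mode (negate when isLargest is false, sort blocks of 32, merge the blocks, take the first k, restore the sign) can be imitated on the host for one row. The sketch below is only a CPU model written in standard C++: std::stable_sort stands in for Sort32, std::inplace_merge for MrgSort, and a resize for GatherMask. TopKRowModel and Pair are hypothetical names, the finishLocal step is omitted, and NaN inputs are assumed absent.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

using Pair = std::pair<float, int32_t>;  // (value, index)

// Descending order by value, the order the text attributes to Sort32.
static bool Descending(const Pair& a, const Pair& b) { return a.first > b.first; }

// CPU model of one row; row.size() plays the role of inner, and 1 <= k <= inner is assumed.
std::vector<Pair> TopKRowModel(const std::vector<float>& row, int32_t k, bool isLargest) {
    const std::ptrdiff_t inner = static_cast<std::ptrdiff_t>(row.size());
    constexpr std::ptrdiff_t kBlock = 32;

    // Generate indexes 0..inner-1 and negate the data when the smallest values are wanted.
    std::vector<Pair> data(row.size());
    for (std::ptrdiff_t i = 0; i < inner; ++i) {
        data[i] = Pair(isLargest ? row[i] : -row[i], static_cast<int32_t>(i));
    }

    // Sort every block of 32 elements (stand-in for Sort32).
    for (std::ptrdiff_t lo = 0; lo < inner; lo += kBlock) {
        std::stable_sort(data.begin() + lo, data.begin() + std::min(lo + kBlock, inner),
                         Descending);
    }

    // Merge sorted blocks pairwise until one sorted run remains (stand-in for MrgSort).
    for (std::ptrdiff_t width = kBlock; width < inner; width *= 2) {
        for (std::ptrdiff_t lo = 0; lo + width < inner; lo += 2 * width) {
            std::inplace_merge(data.begin() + lo, data.begin() + lo + width,
                               data.begin() + std::min(lo + 2 * width, inner), Descending);
        }
    }

    // Keep the first k entries (stand-in for GatherMask), then restore the sign.
    data.resize(static_cast<std::size_t>(k));
    if (!isLargest) {
        for (Pair& p : data) p.first = -p.first;
    }
    return data;
}
```

Negating the data lets a single descending sort serve both directions, which is why the sign is restored only after the first k entries are extracted.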
Prototype
- Allocate the temporary space inside the API.

      template <typename T, bool isInitIndex = false, bool isHasfinish = false, bool isReuseSrc = false,
                enum TopKMode topkMode = TopKMode::TOPK_NORMAL>
      __aicore__ inline void TopK(const LocalTensor<T>& dstValueLocal, const LocalTensor<int32_t>& dstIndexLocal,
                                  const LocalTensor<T>& srcLocal, const LocalTensor<int32_t>& srcIndexLocal,
                                  const LocalTensor<bool>& finishLocal, const int32_t k, const TopkTiling& tilling,
                                  const TopKInfo& topKInfo, const bool isLargest = true)

- Pass the temporary space through the tmpLocal input parameter.

      template <typename T, bool isInitIndex = false, bool isHasfinish = false, bool isReuseSrc = false,
                enum TopKMode topkMode = TopKMode::TOPK_NORMAL>
      __aicore__ inline void TopK(const LocalTensor<T>& dstValueLocal, const LocalTensor<int32_t>& dstIndexLocal,
                                  const LocalTensor<T>& srcLocal, const LocalTensor<int32_t>& srcIndexLocal,
                                  const LocalTensor<bool>& finishLocal, const LocalTensor<uint8_t>& tmpLocal,
                                  const int32_t k, const TopkTiling& tilling, const TopKInfo& topKInfo,
                                  const bool isLargest = true)

Due to the complex logical computation involved in the internal implementation of this API, additional temporary space is required to store intermediate variables generated during computation. The temporary space can be allocated inside the API or passed by developers through the tmpLocal input parameter.
- When the temporary space is allocated inside the API, you do not need to allocate the space, but must reserve the required size for the temporary space.
- When the temporary space is passed through the tmpLocal input parameter, that tensor serves as the temporary space and the API does not allocate any space internally. You manage the tmpLocal space yourself and can reuse the buffer after the API call, so the buffer is not repeatedly allocated and deallocated, which improves flexibility and buffer utilization. To determine the required size of tmpLocal, obtain the maximum and minimum temporary space sizes through the GetTopKMaxMinTmpSize API provided in TopK Tiling.
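Once the minimum and maximum sizes are known, a kernel still has to settle on one size that fits its memory budget. The helper below is a hypothetical host-side sketch of one possible policy, not part of the TopK API: it takes the two sizes as plain integers (the call to GetTopKMaxMinTmpSize is not shown because its signature is not given on this page), caps the size at the available budget, and keeps the result a multiple of 32 bytes. The 32-byte granularity follows the round-up in the sample code on this page; everything else, including treating a budget below the minimum as an error, is an assumption.

```cpp
#include <algorithm>
#include <cstdint>

// Hypothetical helper, not an Ascend C API.
// minBytes / maxBytes: sizes reported for the temporary space.
// budgetBytes: bytes the caller is able to spare for tmpLocal.
// Returns a 32-byte-aligned size, or 0 if the budget cannot cover the minimum.
uint32_t ChooseTmpBytes(uint32_t minBytes, uint32_t maxBytes, uint32_t budgetBytes) {
    const uint32_t alignedMin = (minBytes + 31) / 32 * 32;  // round the minimum up
    if (budgetBytes < alignedMin) {
        return 0;  // not enough room even for the minimum
    }
    // Use as much as is useful and affordable, rounded down to 32 bytes.
    const uint32_t capped = std::min(maxBytes, budgetBytes) / 32 * 32;
    // Never go below the aligned minimum.
    return std::max(capped, alignedMin);
}
```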
Parameters
Table 1 Template parameters

| Parameter | Description |
|---|---|
| T | Type of the data to be sorted. |
| isInitIndex | Whether to pass the index of input data. |
| isHasfinish | The TopK API allows you to specify that the sorting of certain rows is invalid using the finishLocal parameter. This template parameter determines whether to enable that function. The value true indicates that the function is enabled, and the value false indicates that the function is disabled. The value can be true or false in normal mode. The value can only be false in small mode. For details about how to use isHasfinish and finishLocal together, see the description of finishLocal in Table 2. |
| isReuseSrc | Whether the source operand can be modified. This parameter is reserved. Pass the default value false. |
| topkMode | TopK mode selection. The data structure is as follows: |
Table 2 API parameters

| Parameter | Input/Output | Description |
|---|---|---|
| dstValueLocal | Output | Destination operand, which is used to store the sorted k values. The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT. Normal mode: Small mode: |
| dstIndexLocal | Output | Destination operand, which is used to store the indexes corresponding to the sorted k values. The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT. Normal mode: Small mode: |
| srcLocal | Input | Source operand, which is used to store the values to be sorted. The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT. |
| srcIndexLocal | Input | Source operand, which is used to store the indexes corresponding to the values to be sorted. The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT. This parameter works in conjunction with the template parameter isInitIndex. If isInitIndex is set to false, srcIndexLocal only needs to be defined and passed to the API; it does not need to be assigned values. If isInitIndex is set to true, the index values must be passed through srcIndexLocal. The rules for setting srcIndexLocal are as follows: Normal mode: Small mode: |
| finishLocal | Input | Source operand, which is used to specify that the sorting of certain rows is invalid. Its shape is (outter, 1). The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT. This parameter works in conjunction with the template parameter isHasfinish. In normal mode, isHasfinish can be set to true or false. In small mode, isHasfinish can only be set to false. |
| tmpLocal | Input | Temporary space, which is used to store intermediate variables during the internal computation of the API. The temporary space is provided by developers. For details about how to obtain the size of the temporary space, see TopK Tiling. The data type is fixed at uint8_t. The type is LocalTensor, and the logical location can only be VECCALC. |
| k | Input | Number of maximum or minimum values, and their corresponding indexes, to obtain. The data type is int32_t. The value of k must meet the following condition: 1 ≤ k ≤ n. |
| tilling | Input | Tiling information required for TopK computation. For details about how to obtain the tiling information, see TopK Tiling. |
| topKInfo | Input | Shape of srcLocal, TopKInfo type. The specific definition is as follows: |
| isLargest | Input | The value is of the bool type. If this parameter is set to true (the default), data is sorted in descending order and the first k maximum values are obtained. If this parameter is set to false, data is sorted in ascending order and the first k minimum values are obtained. |
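Several of the constraints stated on this page (inner is a multiple of 32 and not smaller than n, 1 ≤ k ≤ n, small mode requires inner = 32 and isHasfinish = false) can be checked on the host before a kernel is launched. The following is a sketch under those stated rules only. ShapeInfo is a hypothetical stand-in that mirrors the three shape quantities named on this page (outter, inner, n); it is not the real TopKInfo definition, and IsValidTopKArgs is likewise an illustrative name.

```cpp
#include <cstdint>

// Hypothetical stand-in for the shape information (not the real TopKInfo type).
struct ShapeInfo {
    int32_t outter;  // number of rows
    int32_t inner;   // padded row length, a multiple of 32
    int32_t n;       // actual number of elements per row
};

// Returns true if the arguments satisfy the constraints listed on this page.
bool IsValidTopKArgs(const ShapeInfo& info, int32_t k, bool smallMode, bool hasFinish) {
    if (info.outter < 1 || info.n < 1) return false;
    // inner must be a multiple of 32 and must cover the n valid elements.
    if (info.inner % 32 != 0 || info.inner < info.n) return false;
    // 1 <= k <= n
    if (k < 1 || k > info.n) return false;
    // Small mode: inner is fixed at 32 and isHasfinish must be false.
    if (smallMode && (info.inner != 32 || hasFinish)) return false;
    return true;
}
```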
Returns
None
Restrictions
- For details about the alignment requirements of the operand address offset, see General Description and Restrictions.
- The source operand address must not overlap the destination operand address.
- If srcLocal[i] and srcLocal[j] are the same and i is greater than j, srcLocal[j] is selected first.
- inf is considered to be the maximum value in TopK.
- nan is placed at the front in TopK sorting regardless of whether the data is sorted in descending or ascending order.
- For the AI Core of the Atlas inference products:
  - If the type of srcLocal is half and the template parameter isInitIndex is set to false, the passed topKInfo.inner cannot be greater than 2048.
  - If the type of srcLocal is half and the template parameter isInitIndex is set to true, the index value of srcIndexLocal cannot be greater than 2048.
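The ordering rules in this list (equal values keep the lower index first, inf counts as the maximum, nan is placed at the front in both directions) can be expressed as a comparator when building host-side expected results. The sketch below is an illustrative CPU model, not the device implementation. ModelOrder is a hypothetical name, and the relative order among several nan values (kept in index order here) is an assumption.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <numeric>
#include <vector>

// Returns the indexes of v in the order described by the restrictions above.
std::vector<int32_t> ModelOrder(const std::vector<float>& v, bool isLargest) {
    std::vector<int32_t> order(v.size());
    std::iota(order.begin(), order.end(), 0);
    // stable_sort preserves index order for elements that compare equal.
    std::stable_sort(order.begin(), order.end(), [&](int32_t a, int32_t b) {
        const bool aNan = std::isnan(v[a]);
        const bool bNan = std::isnan(v[b]);
        if (aNan || bNan) {
            return aNan && !bNan;  // nan goes before any non-nan value
        }
        return isLargest ? v[a] > v[b] : v[a] < v[b];
    });
    return order;
}
```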
Example
This sample implements the code logic of the normal and small modes. To obtain an operator sample project, click sort.

    if (!tmpLocal) {  // Whether to pass the temporary space through the tmpLocal input parameter.
        if (isSmallMode) {  // Small mode
            AscendC::TopK<T, isInitIndex, isHasfinish, isReuseSrc, AscendC::TopKMode::TOPK_NSMALL>(
                dstLocalValue, dstLocalIndex, srcLocalValue, srcLocalIndex, srcLocalFinish,
                k, topKTilingData, topKInfo, isLargest);
        } else {
            AscendC::TopK<T, isInitIndex, isHasfinish, isReuseSrc, AscendC::TopKMode::TOPK_NORMAL>(
                dstLocalValue, dstLocalIndex, srcLocalValue, srcLocalIndex, srcLocalFinish,
                k, topKTilingData, topKInfo, isLargest);
        }
    } else {
        if (tmplocalBytes % 32 != 0) {
            tmplocalBytes = (tmplocalBytes + 31) / 32 * 32;
        }
        pipe.InitBuffer(tmplocalBuf, tmplocalBytes);
        AscendC::LocalTensor<uint8_t> tmplocalTensor = tmplocalBuf.Get<uint8_t>();
        if (isSmallMode) {
            AscendC::TopK<T, isInitIndex, isHasfinish, isReuseSrc, AscendC::TopKMode::TOPK_NSMALL>(
                dstLocalValue, dstLocalIndex, srcLocalValue, srcLocalIndex, srcLocalFinish,
                tmplocalTensor, k, topKTilingData, topKInfo, isLargest);
        } else {
            AscendC::TopK<T, isInitIndex, isHasfinish, isReuseSrc, AscendC::TopKMode::TOPK_NORMAL>(
                dstLocalValue, dstLocalIndex, srcLocalValue, srcLocalIndex, srcLocalFinish,
                tmplocalTensor, k, topKTilingData, topKInfo, isLargest);
        }
    }

| Sample Description | This sample demonstrates the sorting of a float matrix with a shape of (2, 32). It calculates the first five minimum values of each row in the matrix. The API is used in normal mode; the input data indexes are passed, and finishLocal is passed to specify that the sorting of certain rows is invalid. |
|---|---|
| Inputs | |
| Output Data | |

| Sample Description | This sample demonstrates the sorting of float input data with a shape of (4, 17). It calculates the first eight maximum values of each data row. The API is used in small mode, and the input data indexes are passed. |
|---|---|
| Inputs | |
| Output Data | |