TopK Tiling
Function Usage
Obtains the TopK tiling parameters.
Ascend C provides the TopK tiling API, which helps you obtain the tiling parameters required for TopK kernel computation. Before reading this section, refer to Tiling Implementation to learn the basic tiling process.
To obtain the tiling parameters, perform the following two steps:
- Obtain the minimum and maximum temporary space sizes required for TopK API computation. This step is optional: the sizes only serve as a reference for allocating temporary space appropriately.
- Obtain the tiling parameters required by the TopK kernel API.
The TopkTiling structure is defined as follows. You do not need to understand its fields: simply pass the structure to the kernel and use it directly with the TopK high-level API.
```cpp
struct TopkTiling {
    int32_t tmpLocalSize = 0;
    int32_t allDataSize = 0;
    int32_t innerDataSize = 0;
    uint32_t sortRepeat = 0;
    int32_t mrgSortRepeat = 0;
    int32_t kAlignFourBytes = 0;
    int32_t kAlignTwoBytes = 0;
    int32_t maskOffset = 0;
    int32_t maskVreducev2FourBytes = 0;
    int32_t maskVreducev2TwoBytes = 0;
    int32_t mrgSortSrc1offset = 0;
    int32_t mrgSortSrc2offset = 0;
    int32_t mrgSortSrc3offset = 0;
    int32_t mrgSortTwoQueueSrc1Offset = 0;
    int32_t mrgFourQueueTailPara1 = 0;
    int32_t mrgFourQueueTailPara2 = 0;
    int32_t srcIndexOffset = 0;
    uint32_t copyUbToUbBlockCount = 0;
    int32_t topkMrgSrc1MaskSizeOffset = 0;
    int32_t topkNSmallSrcIndexOffset = 0;
    uint32_t vreduceValMask0 = 0;
    uint32_t vreduceValMask1 = 0;
    uint32_t vreduceIdxMask0 = 0;
    uint32_t vreduceIdxMask1 = 0;
    uint16_t vreducehalfValMask0 = 0;
    uint16_t vreducehalfValMask1 = 0;
    uint16_t vreducehalfValMask2 = 0;
    uint16_t vreducehalfValMask3 = 0;
    uint16_t vreducehalfValMask4 = 0;
    uint16_t vreducehalfValMask5 = 0;
    uint16_t vreducehalfValMask6 = 0;
    uint16_t vreducehalfValMask7 = 0;
};
```
Prototype
```cpp
// Obtains the maximum and minimum temporary space sizes (in bytes) required by the TopK API.
bool GetTopKMaxMinTmpSize(const platform_ascendc::PlatformAscendC& ascendcPlatform, const int32_t inner,
    const int32_t outter, const bool isReuseSource, const bool isInitIndex, enum TopKMode mode,
    const bool isLargest, const uint32_t dataTypeSize, uint32_t& maxValue, uint32_t& minValue)

// Obtains the TopK tiling parameters (optiling::TopkTiling overload).
bool TopKTilingFunc(const platform_ascendc::PlatformAscendC& ascendcPlatform, const int32_t inner,
    const int32_t outter, const int32_t k, const uint32_t dataTypeSize, const bool isInitIndex,
    enum TopKMode mode, const bool isLargest, optiling::TopkTiling& topKTiling)

// Obtains the TopK tiling parameters (AscendC::tiling::TopkTiling overload).
bool TopKTilingFunc(const platform_ascendc::PlatformAscendC& ascendcPlatform, const int32_t inner,
    const int32_t outter, const int32_t k, const uint32_t dataTypeSize, const bool isInitIndex,
    enum TopKMode mode, const bool isLargest, AscendC::tiling::TopkTiling& topKTiling)
```
Parameters
Parameters of GetTopKMaxMinTmpSize:

| Parameter | Input/Output | Description |
|---|---|---|
| ascendcPlatform | Input | Hardware platform information. For details about the definition of PlatformAscendC, see Constructor and Destructor. |
| inner | Input | Inner axis length of the srcLocal input of the TopK API. The value must be an integer multiple of 32. |
| outter | Input | Outer axis length of the srcLocal input of the TopK API. |
| isReuseSource | Input | Whether intermediate variables can reuse the input buffer. The value must be the same as that of isReuseSrc on the kernel side. |
| isInitIndex | Input | Whether the indexes corresponding to the input data are passed in. The value must be the same as that used by the kernel API. |
| mode | Input | TopK mode: TopKMode::TOPK_NORMAL or TopKMode::TOPK_NSMALL. The value must be the same as that used by the kernel API. |
| isLargest | Input | Sort order. true indicates descending order and false indicates ascending order. The value must be the same as that used by the kernel API. |
| dataTypeSize | Input | Size in bytes of the srcLocal data type involved in computation, for example, 2 for half and 4 for float. |
| maxValue | Output | Maximum size of the temporary space required for computation inside the TopK API, in bytes. NOTE: maxValue is for reference only and may be larger than the available Unified Buffer space. In this case, choose a temporary space size that fits the available Unified Buffer space. |
| minValue | Output | Minimum size of the temporary space required for computation inside the TopK API, in bytes. |
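GetTopKMaxMinTmpSize only reports a range, so the caller still has to pick a size that fits. The following minimal sketch in plain C++ shows one way to make that choice. It assumes you already know how many bytes of the Unified Buffer remain free.

- ChooseTopKTmpSize and ubFreeBytes are hypothetical names and are not part of Ascend C.
- Preferring the largest size that fits is an assumption, not a documented rule.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical helper (not part of Ascend C): picks a temporary buffer size for
// the TopK API from the [minValue, maxValue] range reported by
// GetTopKMaxMinTmpSize and the number of Unified Buffer bytes still free.
// Returns std::nullopt if even minValue does not fit.
std::optional<uint32_t> ChooseTopKTmpSize(uint32_t minValue, uint32_t maxValue,
                                          uint32_t ubFreeBytes) {
    if (minValue > ubFreeBytes) {
        return std::nullopt;  // The minimum required space is not available.
    }
    // Assumption: use as much of the range as fits, capped by the free space.
    // Since maxValue >= minValue and ubFreeBytes >= minValue, the result is
    // never below minValue.
    return std::min(maxValue, ubFreeBytes);
}
```

In the tiling function of the Example section, the chosen value could be stored with tiling.set_tmpsize() in place of minsize.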
Parameters of TopKTilingFunc:

| Parameter | Input/Output | Description |
|---|---|---|
| ascendcPlatform | Input | Hardware platform information. For details about the definition of PlatformAscendC, see Constructor and Destructor. |
| inner | Input | Inner axis length of the srcLocal input of the TopK API. The value must be an integer multiple of 32. |
| outter | Input | Outer axis length of the srcLocal input of the TopK API. |
| k | Input | Number of maximum or minimum values to obtain, together with their corresponding indexes. |
| dataTypeSize | Input | Size in bytes of the srcLocal data type involved in computation, for example, 2 for half and 4 for float. |
| isInitIndex | Input | Whether the indexes corresponding to the input data are passed in. The value must be the same as that used by the kernel API. |
| mode | Input | TopK mode: TopKMode::TOPK_NORMAL or TopKMode::TOPK_NSMALL. The value must be the same as that used by the kernel API. |
| isLargest | Input | Sort order. true indicates descending order and false indicates ascending order. The value must be the same as that used by the kernel API. |
| topKTiling | Output | TopK tiling structure (optiling::TopkTiling or AscendC::tiling::TopkTiling, depending on the overload), filled with the tiling information required by the TopK kernel API. |
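Because inner must be an integer multiple of 32, a host-side check before calling TopKTilingFunc can catch invalid shapes early. The following sketch is hypothetical:

- AlignUp and CheckTopKArgs are illustrative helpers, not Ascend C APIs.
- Only the multiple-of-32 rule and the example type sizes (half = 2, float = 4) come from this page.
- The remaining checks are assumptions.

```cpp
#include <cassert>
#include <cstdint>

// Rounds n up to the next multiple of align (align must be > 0).
constexpr int32_t AlignUp(int32_t n, int32_t align) {
    return (n + align - 1) / align * align;
}

// Hypothetical pre-check before calling TopKTilingFunc.
bool CheckTopKArgs(int32_t inner, int32_t outter, int32_t k, uint32_t dataTypeSize) {
    if (inner <= 0 || inner % 32 != 0) return false;  // documented: multiple of 32
    if (outter <= 0) return false;                     // assumption: at least one row
    if (k <= 0 || k > inner) return false;             // assumption: 1 <= k <= inner
    // Assumption: only half (2 bytes) and float (4 bytes) are accepted here;
    // this page lists them only as examples.
    return dataTypeSize == 2 || dataTypeSize == 4;
}
```

If the real row length is not a multiple of 32, AlignUp(n, 32) gives a padded inner length. How the padding elements must be filled so that they are never selected is not covered by this page.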
Returns
GetTopKMaxMinTmpSize returns true if the maximum and minimum temporary space sizes required for TopK internal computation are obtained successfully, and false otherwise.
TopKTilingFunc returns true if the TopK tiling parameters are obtained successfully, and false otherwise.
Restrictions
None
Example
The following example shows how to obtain the TopK tiling parameters on the host and how to use them on the kernel when the TopK high-level API is used.
- Add the TopkTiling structure to the TilingData structure as a field.
```cpp
namespace optiling {
BEGIN_TILING_DATA_DEF(TilingData)
    TILING_DATA_FIELD_DEF(uint32_t, totalLength);
    TILING_DATA_FIELD_DEF(uint32_t, tileNum);  // set with set_tileNum() in the tiling function
    // Add other tiling fields.
    ...
    TILING_DATA_FIELD_DEF(int32_t, k);
    TILING_DATA_FIELD_DEF(bool, islargest);
    TILING_DATA_FIELD_DEF(bool, isinitindex);
    TILING_DATA_FIELD_DEF(bool, ishasfinish);
    TILING_DATA_FIELD_DEF(uint32_t, tmpsize);
    TILING_DATA_FIELD_DEF(int32_t, outter);
    TILING_DATA_FIELD_DEF(int32_t, inner);
    TILING_DATA_FIELD_DEF(int32_t, n);
    TILING_DATA_FIELD_DEF(int32_t, order);
    TILING_DATA_FIELD_DEF(int32_t, sorted);
    TILING_DATA_FIELD_DEF_STRUCT(TopkTiling, topkTilingData);  // TopK tiling parameters
END_TILING_DATA_DEF;

REGISTER_TILING_DATA_CLASS(TopkCustom, TilingData)
}  // namespace optiling
```
- The tiling implementation function calls TopKTilingFunc to obtain the tiling parameters required by the TopK kernel API based on the input shape and other information. It also calls GetTopKMaxMinTmpSize to obtain the maximum and minimum temporary space sizes required by the TopK API. It then sets an appropriate temporary space size within this range, based on the actual buffer usage.
```cpp
namespace optiling {
const uint32_t BLOCK_DIM = 8;
const uint32_t TILE_NUM = 8;
const int32_t OUTTER = 2;
const int32_t INNER = 32;
const int32_t N = 32;
const int32_t K = 8;
const bool IS_LARGEST = true;
const bool IS_INITINDEX = true;
const bool IS_REUSESOURCE = false;

static ge::graphStatus TilingFunc(gert::TilingContext* context)
{
    TilingData tiling;
    uint32_t totalLength = context->GetInputTensor(0)->GetShapeSize();
    context->SetBlockDim(BLOCK_DIM);
    tiling.set_totalLength(totalLength);
    tiling.set_tileNum(TILE_NUM);
    tiling.set_k(K);
    tiling.set_outter(OUTTER);
    tiling.set_inner(INNER);
    tiling.set_n(N);
    tiling.set_islargest(IS_LARGEST);
    tiling.set_isinitindex(IS_INITINDEX);
    // Set other tiling parameters.
    ...
    uint32_t maxsize = 0;
    uint32_t minsize = 0;
    uint32_t dtypesize = 4;  // float
    auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
    // mode, isInitIndex, and isLargest must match the values used by the kernel.
    if (!AscendC::TopKTilingFunc(ascendcPlatform, INNER, OUTTER, K, dtypesize, IS_INITINDEX,
                                 AscendC::TopKMode::TOPK_NSMALL, IS_LARGEST, tiling.topkTilingData)) {
        return ge::GRAPH_FAILED;
    }
    if (!AscendC::GetTopKMaxMinTmpSize(ascendcPlatform, INNER, OUTTER, IS_REUSESOURCE, IS_INITINDEX,
                                       AscendC::TopKMode::TOPK_NSMALL, IS_LARGEST, dtypesize,
                                       maxsize, minsize)) {
        return ge::GRAPH_FAILED;
    }
    // This example is for reference only. It passes the minimum size obtained from
    // GetTopKMaxMinTmpSize, which ensures correct functionality. You can pass a
    // larger space size as required.
    tiling.set_tmpsize(minsize);
    ...  // Other logic
    tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity());
    context->GetRawTilingData()->SetDataSize(tiling.GetDataSize());
    size_t *currentWorkspace = context->GetWorkspaceSizes(1);
    currentWorkspace[0] = 0;
    return ge::GRAPH_SUCCESS;
}
}  // namespace optiling
```
- On the kernel side, call GET_TILING_DATA in the kernel function to obtain the tiling data, and then pass the TopK tiling information (topkTilingData) to the TopK API for computation. For the complete kernel example, see Example.
```cpp
extern "C" __global__ __aicore__ void topk_custom(GM_ADDR srcVal, GM_ADDR srcIdx, GM_ADDR finishLocal,
                                                  GM_ADDR dstVal, GM_ADDR dstIdx, GM_ADDR tiling)
{
    GET_TILING_DATA(tilingData, tiling);
    // The TopKMode here must match the mode passed to TopKTilingFunc on the host.
    KernelTopK<float, true, true, false, false, AscendC::TopKMode::TOPK_NSMALL> op;
    op.Init(srcVal, srcIdx, finishLocal, dstVal, dstIdx, tilingData.k, tilingData.islargest,
            tilingData.tmpsize, tilingData.outter, tilingData.inner, tilingData.n,
            tilingData.topkTilingData);
    op.Process();
}
```
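To check kernel results on the host, it can help to compare against a straightforward CPU version of what the parameters describe. For each of the outter rows of inner values, it returns either of the following, together with their positions:

- the k largest values in descending order (isLargest = true), or
- the k smallest values in ascending order (isLargest = false).

This is a plain C++ sketch, not the algorithm used by the TopK API, and it makes these assumptions:

- ReferenceTopK is a hypothetical name.
- Indexes are positions 0 to inner - 1 within each row, as when no indexes are passed in.
- Ties are broken by position, which may differ from the kernel's behavior.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

// Hypothetical CPU reference for TopK (not the Ascend C implementation).
// src holds outter rows of inner values in row-major order. The outputs hold
// outter rows of k values and k indexes (positions within the row).
template <typename T>
void ReferenceTopK(const std::vector<T>& src, int32_t outter, int32_t inner, int32_t k,
                   bool isLargest, std::vector<T>& dstVal, std::vector<int32_t>& dstIdx) {
    assert(k > 0 && k <= inner);
    assert(src.size() == static_cast<size_t>(outter) * static_cast<size_t>(inner));
    dstVal.assign(static_cast<size_t>(outter) * static_cast<size_t>(k), T{});
    dstIdx.assign(static_cast<size_t>(outter) * static_cast<size_t>(k), 0);
    std::vector<int32_t> pos(static_cast<size_t>(inner));
    for (int32_t row = 0; row < outter; ++row) {
        const size_t base = static_cast<size_t>(row) * static_cast<size_t>(inner);
        const size_t out = static_cast<size_t>(row) * static_cast<size_t>(k);
        std::iota(pos.begin(), pos.end(), 0);  // positions 0..inner-1 in this row
        // Descending order for isLargest, ascending otherwise; stable_sort keeps
        // equal values in input order (an assumed tie-breaking rule).
        std::stable_sort(pos.begin(), pos.end(), [&](int32_t a, int32_t b) {
            return isLargest ? src[base + a] > src[base + b]
                             : src[base + a] < src[base + b];
        });
        for (int32_t j = 0; j < k; ++j) {
            dstVal[out + j] = src[base + pos[j]];
            dstIdx[out + j] = pos[j];
        }
    }
}
```

For example, for the rows {3, 1, 4, 1} and {5, 9, 2, 6} with k = 2 and isLargest = true, the sketch yields values {4, 3, 9, 6} and indexes {2, 0, 1, 3}.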