GetSwiGLUTmpBufferFactorSize

Function

SwiGLU computation in a kernel requires temporary space, which developers must reserve or allocate. The maximum temporary space (maxTmpBuffer) relates to the space occupied by the input (inputSize × typeSize) as follows:

maxTmpBuffer = maxLiveNodeCount * inputSize * typeSize + extraBuffer

maxLiveNodeCount is the ratio of the input-proportional part of the maximum temporary space to the space occupied by the input. extraBuffer is the size of the additional temporary space, in bytes.
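To make the formula concrete, the following C++ sketch evaluates it for given factors. ComputeMaxTmpBuffer is a hypothetical helper, not part of Ascend C. It assumes inputSize is the number of input elements and typeSize the bytes per element; in practice the factor values would come from GetSwiGLUTmpBufferFactorSize.

```cpp
#include <cstdint>

// Hypothetical helper (not an Ascend C API) that evaluates
// maxTmpBuffer = maxLiveNodeCount * inputSize * typeSize + extraBuffer.
// inputSize: number of input elements; typeSize: bytes per element.
// The product is widened to 64 bits so large inputs cannot overflow uint32_t.
uint64_t ComputeMaxTmpBuffer(uint32_t maxLiveNodeCount, uint32_t inputSize,
                             uint32_t typeSize, uint32_t extraBuffer)
{
    return static_cast<uint64_t>(maxLiveNodeCount) * inputSize * typeSize + extraBuffer;
}
```

For instance, with illustrative factors maxLiveNodeCount = 3 and extraBuffer = 0 (not values returned by the API), 1024 half-precision elements need 3 × 1024 × 2 = 6144 bytes.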

This API returns maxLiveNodeCount and extraBuffer. When the temporary space has a fixed size, these two values determine the maximum number of elements the operator can process in a single call.

For example, suppose an operator implementation calls the SwiGLU API and reserves currBuff bytes of temporary space. The developer calls GetSwiGLUTmpBufferFactorSize to obtain maxLiveNodeCount and extraBuffer, and then computes the maximum number of elements (currentShapeSize) that SwiGLU can process in one call:

currentShapeSize = (currBuff - extraBuffer) / maxLiveNodeCount / typeSize
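The formula above can be sketched as a host-side C++ helper. ComputeCurrentShapeSize is a hypothetical name, not an Ascend C API; the sketch assumes currBuff and extraBuffer are byte counts and relies on unsigned integer division, which rounds down.

```cpp
#include <cstdint>

// Hypothetical helper (not an Ascend C API) that evaluates
// currentShapeSize = (currBuff - extraBuffer) / maxLiveNodeCount / typeSize.
// currBuff and extraBuffer are in bytes; the result is an element count.
// Returns 0 when the formula cannot be applied (currBuff does not exceed
// extraBuffer, or a divisor is zero); the caller must handle that case.
uint32_t ComputeCurrentShapeSize(uint32_t currBuff, uint32_t extraBuffer,
                                 uint32_t maxLiveNodeCount, uint32_t typeSize)
{
    if (maxLiveNodeCount == 0 || typeSize == 0 || currBuff <= extraBuffer) {
        return 0;
    }
    // Integer division rounds down, so currentShapeSize elements never need
    // more than currBuff bytes under the maxTmpBuffer formula.
    return (currBuff - extraBuffer) / maxLiveNodeCount / typeSize;
}
```

With currBuff = 8192, extraBuffer = 0, maxLiveNodeCount = 3 and typeSize = 2 (illustrative numbers, not values returned by the API), the helper yields 1365 elements, which need 3 × 1365 × 2 = 8190 bytes.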

Prototype

void GetSwiGLUTmpBufferFactorSize(const uint32_t typeSize, uint32_t& maxLiveNodeCount, uint32_t& extraBuffer)

Parameters

Table 1 Parameters

| Parameter | Input/Output | Description |
|---|---|---|
| typeSize | Input | Size of the input data type, in bytes. For example, if the input data type is half, set this parameter to 2. |
| maxLiveNodeCount | Output | Maximum number of live nodes: the ratio of the input-proportional part of the maximum temporary space to the space occupied by the input. |
| extraBuffer | Output | Size of the extra temporary space, in bytes. |

Returns

None

Restrictions

If the currentShapeSize derived from maxLiveNodeCount and extraBuffer satisfies currentShapeSize × typeSize < 256 bytes, raise currentShapeSize to 256B / typeSize, rounding up if the division is not exact.
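One reading of this restriction, sketched in C++: ApplyMinComputeSize is a hypothetical helper (not an Ascend C API) that raises the element count to ceil(256 / typeSize) whenever the derived size covers fewer than 256 bytes, and otherwise leaves it unchanged.

```cpp
#include <cstdint>

// Minimum number of bytes per computation, per the restriction above.
constexpr uint32_t kMinComputeBytes = 256;

// Hypothetical helper (not an Ascend C API): if currentShapeSize * typeSize
// is below 256 bytes, return 256 / typeSize rounded up; otherwise return
// currentShapeSize unchanged.
uint32_t ApplyMinComputeSize(uint32_t currentShapeSize, uint32_t typeSize)
{
    if (typeSize == 0) {
        return currentShapeSize;  // invalid type size: leave the value as-is
    }
    // Compare in 64 bits so the multiplication cannot overflow.
    if (static_cast<uint64_t>(currentShapeSize) * typeSize < kMinComputeBytes) {
        // Ceiling division: smallest element count that covers 256 bytes.
        return (kMinComputeBytes + typeSize - 1) / typeSize;
    }
    return currentShapeSize;
}
```

With typeSize = 2, a derived currentShapeSize of 100 (200 bytes) becomes 128, while 1365 is left unchanged.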

Example

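uint32_t typeSize = sizeof(half);  // Input data type is half, so typeSize is 2 bytes.
uint32_t maxLiveNodeCount = 0;     // Output: ratio of the maximum temporary space to the input space.
uint32_t extraBuffer = 0;          // Output: extra temporary space, in bytes.

// Query the SwiGLU temporary-space factors for half-precision input.
AscendC::GetSwiGLUTmpBufferFactorSize(typeSize, maxLiveNodeCount, extraBuffer);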