SetSelfDefineData

Function Description

Sets user-defined information for the callback functions when the MatmulCallBack template parameter (custom callback functions) is enabled. This can be computation data that the callbacks require, or the address of such data stored on GM. The value is delivered to each callback through its dataPtr parameter. This API can be called multiple times.
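The following host-side C++ sketch simulates this pass-through with ordinary host pointers instead of GM, to show how an address is turned into an opaque 64-bit value and turned back into a typed pointer inside a callback. FakeMatmul, CallbackParams, MyCopyA1, and RunDemo are hypothetical names for illustration only. They are not Ascend C APIs. The rule that dataPtr stays 0 until it is set is an assumption of this simulation, not documented behavior of the real Matmul object.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical parameter block a user might place in memory for the callback.
struct CallbackParams {
    float scale;
    int32_t rowOffset;
};

// Simulated callback signature: the last argument is the opaque 64-bit value
// registered through SetSelfDefineData.
using CopyCallback = void (*)(int row, int col, uint64_t dataPtr);

// Hypothetical stand-in for the Matmul object (not the Ascend C class): it
// stores the opaque value and forwards it unchanged to the callback.
class FakeMatmul {
public:
    explicit FakeMatmul(CopyCallback cb) : cb_(cb) {}
    void SetSelfDefineData(uint64_t dataPtr) { dataPtr_ = dataPtr; }
    void Iterate(int row, int col) { cb_(row, col, dataPtr_); }

private:
    CopyCallback cb_;
    uint64_t dataPtr_ = 0;  // simulation assumption: 0 until set
};

// Values recorded by the callback so the result can be inspected.
int g_effectiveRow = -1;
float g_scale = 0.0f;

// User callback: converts the opaque value back into a typed pointer.
void MyCopyA1(int row, int col, uint64_t dataPtr) {
    (void)col;  // unused in this sketch
    assert(dataPtr != 0 && "SetSelfDefineData was not called");
    const auto *p = reinterpret_cast<const CallbackParams *>(dataPtr);
    g_effectiveRow = row + p->rowOffset;
    g_scale = p->scale;
}

// Registers a parameter block, runs one iteration, and returns the row
// computed by the callback.
int RunDemo() {
    static CallbackParams params{0.5f, 16};
    FakeMatmul mm(MyCopyA1);
    mm.SetSelfDefineData(reinterpret_cast<uint64_t>(&params));
    mm.Iterate(2, 0);
    return g_effectiveRow;
}
```

In this sketch RunDemo() returns 18, which is row 2 plus the rowOffset of 16 read through dataPtr.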

Prototype

__aicore__ inline void SetSelfDefineData(const uint64_t dataPtr)

Parameters

Table 1 Parameters

Parameter | Input/Output | Description
--- | --- | ---
dataPtr | Input | Computation data required by the callback functions, or the address of such data stored on GM, passed as a 64-bit value.
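Because dataPtr is only a 64-bit integer, the operator code and its callbacks must agree on how it is encoded. The host-side sketch below shows two possible conventions: storing a small value such as a float directly in the low 32 bits, or storing the address of a larger structure. PackScale, UnpackScale, PackAddress, UnpackAddress, and OffsetTable are hypothetical helpers that are not part of the Ascend C API. The sketch also uses host memory rather than GM.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Convention 1 (hypothetical helpers): carry a small value directly in the 64 bits.
uint64_t PackScale(float scale) {
    uint32_t bits = 0;
    std::memcpy(&bits, &scale, sizeof(bits));  // bit-copy, no numeric conversion
    return static_cast<uint64_t>(bits);
}

float UnpackScale(uint64_t dataPtr) {
    uint32_t bits = static_cast<uint32_t>(dataPtr);  // low 32 bits hold the float
    float scale = 0.0f;
    std::memcpy(&scale, &bits, sizeof(scale));
    return scale;
}

// Convention 2 (hypothetical helpers): carry the address of a larger block.
struct OffsetTable {
    int32_t offsets[4];
};

uint64_t PackAddress(const OffsetTable *table) {
    return static_cast<uint64_t>(reinterpret_cast<uintptr_t>(table));
}

const OffsetTable *UnpackAddress(uint64_t dataPtr) {
    return reinterpret_cast<const OffsetTable *>(static_cast<uintptr_t>(dataPtr));
}
```

Bit-copying through std::memcpy preserves the exact float bits. A plain integer cast would instead convert the numeric value and lose the fraction.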

Returns

None

Availability

Precautions

If the callback functions use dataPtr, this API must be called before the Matmul computation runs. If they do not use it, this API does not need to be called.

Example

// User-defined callback functions. Each one receives the value registered
// through SetSelfDefineData in its dataPtr parameter.
void DataCopyOut(const __gm__ void *gm, const LocalTensor<int8_t> &co1Local, const void *dataCopyOutParams, const uint64_t tilingPtr, const uint64_t dataPtr);
void CopyA1(const LocalTensor<int8_t> &aMatrix, const __gm__ void *gm, int row, int col, int useM, int useK, const uint64_t tilingPtr, const uint64_t dataPtr);
void CopyB1(const LocalTensor<int8_t> &bMatrix, const __gm__ void *gm, int row, int col, int useK, int useN, const uint64_t tilingPtr, const uint64_t dataPtr);

typedef matmul::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, half> aType;
typedef matmul::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, half> bType;
typedef matmul::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, float> cType;
typedef matmul::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, float> biasType;
// Enable the custom callbacks through the MatmulCallBackFunc template argument.
matmul::Matmul<aType, bType, cType, biasType, CFG_NORM, MatmulCallBackFunc<DataCopyOut, CopyA1, CopyB1>> mm;
REGIST_MATMUL_OBJ(&pipe, GetSysWorkSpacePtr(), mm, &tiling);
// GM buffer holding the computation data required by the callback functions
// (SrcT and the binding of dataGM to its GM buffer are defined elsewhere in the kernel).
GlobalTensor<SrcT> dataGM;
// Pass the GM address as an opaque 64-bit value; the callbacks receive it as dataPtr.
uint64_t dataGMPtr = reinterpret_cast<uint64_t>(dataGM.address_);
mm.SetSelfDefineData(dataGMPtr);
// gmA, gmB, and gmC are the GlobalTensors for matrices A, B, and C, set up elsewhere.
mm.SetTensorA(gmA);
mm.SetTensorB(gmB);
mm.IterateAll(gmC);