SetUserDefInfo

Applicability

Product                                                   Supported
Atlas A3 training products/Atlas A3 inference products    √
Atlas A2 training products/Atlas A2 inference products    √
Atlas 200I/500 A2 inference products                      x
AI Core of Atlas inference products                       x
Vector Core of Atlas inference products                   x
Atlas training products                                   x

Function

Sets the operator tiling address so that the custom callback functions can use it. This API applies when the template parameter MatmulCallBackFunc (custom callback functions) is enabled, and it needs to be called only once.

Prototype

__aicore__ inline void SetUserDefInfo(const uint64_t tilingPtr)

Parameters

Table 1 Parameters

Parameter    Input/Output    Description
tilingPtr    Input           Tiling address of the operator

Returns

None

Restrictions

  • If tilingPtr is required in the callback function, this API must be called. If it is not required, this API does not need to be called.
  • This API is not supported when enableMixDualMaster (dual-master mode) is set to true.

Example

// User-defined callback functions; each one receives tilingPtr as a parameter
void DataCopyOut(const __gm__ void *gm, const LocalTensor<int8_t> &co1Local, const void *dataCopyOutParams, const uint64_t tilingPtr, const uint64_t dataPtr);
void CopyA1(const LocalTensor<int8_t> &aMatrix, const __gm__ void *gm, int row, int col, int useM, int useK, const uint64_t tilingPtr, const uint64_t dataPtr);
void CopyB1(const LocalTensor<int8_t> &bMatrix, const __gm__ void *gm, int row, int col, int useK, int useN, const uint64_t tilingPtr, const uint64_t dataPtr);

// Matrix types for A, B, C, and bias
typedef AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, half> aType;
typedef AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, half> bType;
typedef AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, float> cType;
typedef AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, float> biasType;
// Enable the custom callbacks through the MatmulCallBackFunc template parameter
AscendC::Matmul<aType, bType, cType, biasType, CFG_NORM, MatmulCallBackFunc<DataCopyOut, CopyA1, CopyB1>> mm;
REGIST_MATMUL_OBJ(&pipe, GetSysWorkSpacePtr(), mm, &tiling);
// Pass the operator tiling address to the callbacks; one call is enough
uint64_t tilingPtr = reinterpret_cast<uint64_t>(tiling);
mm.SetUserDefInfo(tilingPtr);
mm.SetTensorA(gmA);
mm.SetTensorB(gmB);
mm.IterateAll();
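
The example above does not show how a callback might consume tilingPtr. The following host-side C++ sketch illustrates only the general pattern: an address is stored once as a 64-bit integer and each callback casts it back to a structure pointer. It is not Ascend C kernel code. DemoTiling, MockMatmul, CopyA1Sketch, and RunDemo are hypothetical names introduced for illustration, and the mock class's SetUserDefInfo only imitates the role of the real API. A real callback would use the signatures shown above and the operator's own tiling structure.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical tiling layout, for illustration only.
// A real operator defines its own tiling structure.
struct DemoTiling {
    uint32_t baseM;  // rows per tile of matrix A
    uint32_t baseK;  // columns per tile of matrix A
    uint32_t K;      // total columns of matrix A (row-major)
};

// Hypothetical stand-in for the matmul object: it stores the tiling
// address once and forwards it to every callback invocation.
class MockMatmul {
public:
    using CopyCallback = uint32_t (*)(int row, int col, uint64_t tilingPtr);

    explicit MockMatmul(CopyCallback cb) : cb_(cb) {}

    // Imitates the role of SetUserDefInfo: record the tiling address.
    void SetUserDefInfo(const uint64_t tilingPtr) { tilingPtr_ = tilingPtr; }

    // Invokes the callback for tile (row, col) with the stored address.
    uint32_t CopyTile(int row, int col) const { return cb_(row, col, tilingPtr_); }

private:
    CopyCallback cb_;
    uint64_t tilingPtr_ = 0;
};

// Sketch of a callback: recover the tiling structure from the integer
// address, then compute the element offset of tile (row, col) in A.
uint32_t CopyA1Sketch(int row, int col, const uint64_t tilingPtr) {
    // The address must have been set before any callback runs.
    assert(tilingPtr != 0);
    const auto *tiling = reinterpret_cast<const DemoTiling *>(tilingPtr);
    return static_cast<uint32_t>(row) * tiling->baseM * tiling->K +
           static_cast<uint32_t>(col) * tiling->baseK;
}

// Usage: set the address once, then let the callback read it.
uint32_t RunDemo() {
    DemoTiling tiling{16, 32, 128};
    MockMatmul mm(CopyA1Sketch);
    mm.SetUserDefInfo(reinterpret_cast<uint64_t>(&tiling));
    return mm.CopyTile(1, 2);  // 1 * 16 * 128 + 2 * 32 = 2112
}
```

The sketch sets the address a single time before the first callback runs, which matches the note that this API needs to be called only once. The assertion inside the callback reflects the restriction that the API must be called whenever a callback relies on tilingPtr.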