SetBatchInfoForNormal

Function

Sets the batch counts of matrix A and matrix B and the lengths of the M, N, and K axes. When the NORMAL layout is used, call this API in the host-side tiling implementation before IterateBatch or IterateNBatch is called.
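To illustrate what these values describe, the sketch below is a plain CPU reference for a batched matrix multiply. It assumes that the NORMAL layout stores each operand as contiguous, row-major batches (A as [batch][m][k], B as [batch][k][n], C as [batch][m][n]) and that batchA equals batchB. The function name BatchMatmulNormal is hypothetical and is not part of the Ascend C API.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical CPU reference, assuming the NORMAL layout means contiguous BMNK storage:
// A = [batch][m][k], B = [batch][k][n], C = [batch][m][n], all row-major.
// Assumption: batchA == batchB == batch; broadcasting is not modeled.
std::vector<float> BatchMatmulNormal(const std::vector<float>& a, const std::vector<float>& b,
                                     int32_t batch, int32_t m, int32_t n, int32_t k) {
    std::vector<float> c(static_cast<std::size_t>(batch) * m * n, 0.0f);
    for (int32_t bi = 0; bi < batch; ++bi) {
        // Start offset of batch bi inside each flattened tensor.
        const std::size_t aOff = static_cast<std::size_t>(bi) * m * k;
        const std::size_t bOff = static_cast<std::size_t>(bi) * k * n;
        const std::size_t cOff = static_cast<std::size_t>(bi) * m * n;
        for (int32_t i = 0; i < m; ++i) {
            for (int32_t j = 0; j < n; ++j) {
                float acc = 0.0f;
                // Dot product of row i of A_bi with column j of B_bi along the K axis.
                for (int32_t p = 0; p < k; ++p) {
                    acc += a[aOff + i * k + p] * b[bOff + p * n + j];
                }
                c[cOff + i * n + j] = acc;
            }
        }
    }
    return c;
}
```

Each batch is an independent M x K by K x N product; the batch index only shifts the start offset of each operand.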

Prototype

int32_t SetBatchInfoForNormal(int32_t batchA, int32_t batchB, int32_t m, int32_t n, int32_t k)

Parameters

Table 1 Parameters

Parameter   Input/Output   Description
batchA      Input          Number of batches of matrix A
batchB      Input          Number of batches of matrix B
m           Input          Length of the M axis of matrix A
n           Input          Length of the N axis of matrix B
k           Input          Length of the K axis, shared by matrix A and matrix B
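As a worked example of the sizes these parameters imply, the sketch below computes the element counts of A and B. It assumes the same contiguous NORMAL (BMNK) layout as above, with A as [batchA][m][k] and B as [batchB][k][n]; the struct and function names are hypothetical helpers, not library APIs.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: element counts implied by the SetBatchInfoForNormal arguments,
// assuming A = [batchA][m][k] and B = [batchB][k][n] stored contiguously.
struct NormalLayoutSizes {
    int64_t aElems;
    int64_t bElems;
};

NormalLayoutSizes ComputeNormalLayoutSizes(int32_t batchA, int32_t batchB,
                                           int32_t m, int32_t n, int32_t k) {
    // Widen to 64 bits before multiplying so large shapes do not overflow int32_t.
    return {static_cast<int64_t>(batchA) * m * k,
            static_cast<int64_t>(batchB) * k * n};
}
```

With the values from the example on this page (batchA = batchB = 3, m = 32, n = 256, k = 64), A holds 3 x 32 x 64 = 6144 elements and B holds 3 x 64 x 256 = 49152 elements, which is 12288 and 98304 bytes in float16.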

Returns

0: the setting succeeded. -1: the setting failed.
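The sketch below mirrors the 0/-1 return convention in a hypothetical host-side pre-check. The function name CheckBatchInfoForNormal and its failure condition (any non-positive batch count or axis length) are assumptions made for illustration; this section does not state when the real API returns -1.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical pre-check following the same convention as SetBatchInfoForNormal:
// 0 means the values are accepted, -1 means they are rejected.
// Assumption: non-positive batch counts or axis lengths are invalid.
int32_t CheckBatchInfoForNormal(int32_t batchA, int32_t batchB,
                                int32_t m, int32_t n, int32_t k) {
    if (batchA <= 0 || batchB <= 0 || m <= 0 || n <= 0 || k <= 0) {
        return -1;  // setting would fail
    }
    return 0;  // setting would succeed
}
```

Whatever the exact failure conditions, a caller should compare the real return value with 0 and stop the tiling flow on -1 rather than continuing to GetTiling.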

Restrictions

When the NORMAL layout is used, this API must be called in the host-side tiling implementation before IterateBatch or IterateNBatch is called.

Example

// Runs inside the host-side tiling function; context is the tiling context passed to it.
auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
matmul_tiling::MultiCoreMatmulTiling tiling(ascendcPlatform);
int32_t M = 32;
int32_t N = 256;
int32_t K = 64;
tiling.SetDim(1);
tiling.SetAType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT16);
tiling.SetBType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT16);
tiling.SetCType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT);
tiling.SetBiasType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT);
tiling.SetShape(M, N, K);
tiling.SetOrgShape(M, N, K);
tiling.SetBias(true);

constexpr int32_t BATCH_NUM = 3;
// Set the batch counts of A and B and the M, N, and K axes for the NORMAL layout.
tiling.SetBatchInfoForNormal(BATCH_NUM, BATCH_NUM, M, N, K);
tiling.SetBufferSpace(-1, -1, -1);

optiling::TCubeTiling tilingData;
int ret = tiling.GetTiling(tilingData);  // -1 indicates that tiling failed
if (ret == -1) {
    return ge::GRAPH_FAILED;
}