Channel Splitting of the Matrix Multiplication Output

Overview

Channel splitting of the matrix multiplication output is also called ChannelSplit. When the C matrix produced by a Matmul computation uses the NZ format, it is stored in fractals. For details about the NZ format, see Data Formats. By default, when the physical layout of the C matrix is NZ and the data type is float, each fractal holds 16 x 16 elements. ChannelSplit splits each 16 x 16 fractal of the C matrix into two 16 x 8 fractals, so that the whole C matrix is stored in 16 x 8 fractals.
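The effect of the fractal size on element placement can be sketched in plain host-side C++. This is a minimal illustration, not part of the Matmul API: the function name NzOffset and its parameters are hypothetical, and the addressing rule is an assumption about the NZ format (fractals ordered by column block first and then by row block, with row-major elements inside each fractal). Check it against Data Formats before relying on it.

```cpp
#include <cstddef>

// Hypothetical helper: element offset (in elements, not bytes) of C[row][col]
// in an NZ-format buffer. Assumed rule: fractals are ordered by column block,
// then by row block; elements inside a fractal are stored row by row.
constexpr std::size_t NzOffset(std::size_t row, std::size_t col,
                               std::size_t totalRows,
                               std::size_t fractalRows,
                               std::size_t fractalCols)
{
    // Number of fractals stacked along the row (M) direction.
    const std::size_t rowBlocks = (totalRows + fractalRows - 1) / fractalRows;
    // Position of the fractal that contains the element.
    const std::size_t colBlock = col / fractalCols;
    const std::size_t rowBlock = row / fractalRows;
    const std::size_t fractalIndex = colBlock * rowBlocks + rowBlock;
    // Position of the element inside its fractal.
    const std::size_t inFractal =
        (row % fractalRows) * fractalCols + (col % fractalCols);
    return fractalIndex * fractalRows * fractalCols + inFractal;
}

// 32 x 32 float matrix, element (1, 9):
// default 16 x 16 fractals keep it in the first fractal,
static_assert(NzOffset(1, 9, 32, 16, 16) == 25, "16 x 16 fractal offset");
// while 16 x 8 fractals move it to the third fractal (index 2).
static_assert(NzOffset(1, 9, 32, 16, 8) == 265, "16 x 8 fractal offset");
```

Under this assumed rule, code that reads the C matrix after ChannelSplit is enabled must use an inner-axis width of 8 rather than 16.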

Because a float element occupies 4 bytes, each row of a 16 x 8 fractal holds 8 x 4 = 32 bytes, so the inner axis is 32-byte aligned. This amount of data matches the data unit processed by an NPU vector instruction, which facilitates subsequent computation. ChannelSplit is disabled by default. To enable it, set the isEnableChannelSplit parameter in MatmulConfig to true.
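The byte arithmetic behind the 32-byte alignment can be checked at compile time. The sketch below is plain C++ with hypothetical names (InnerAxisBytes, kVectorUnitBytes). The 32-byte unit is taken from the text above, and the 4-byte float size is assumed for the target platform.

```cpp
#include <cstddef>

// Data unit handled by a vector instruction, as stated in the text (32 bytes).
constexpr std::size_t kVectorUnitBytes = 32;

// Hypothetical helper: bytes occupied by one fractal row of float elements.
constexpr std::size_t InnerAxisBytes(std::size_t fractalCols)
{
    return fractalCols * sizeof(float);
}

static_assert(sizeof(float) == 4, "this sketch assumes a 4-byte float");
// With ChannelSplit: 8 floats per row, exactly one 32-byte unit.
static_assert(InnerAxisBytes(8) == kVectorUnitBytes, "16 x 8 row is 32 bytes");
// Default 16 x 16 fractal: 16 floats per row, two 32-byte units.
static_assert(InnerAxisBytes(16) == 2 * kVectorUnitBytes, "16 x 16 row is 64 bytes");
```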

Figure 1 ChannelSplit function

Use Case

Use ChannelSplit when a C matrix in the NZ format with the float data type must be stored in 16 x 8 fractals.

Restrictions

To enable the ChannelSplit function, the following conditions must be met:

  • The data layout format of the C matrix is CubeFormat::NZ.
  • The data type of the C matrix is float.
  • The logical memory location of the C matrix is the global memory.
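The three conditions can be expressed as a single predicate. The sketch below is standalone C++ and does not use the Matmul API: the enums Layout, ElemType, and Location, the struct OutputDesc, and the function CanEnableChannelSplit are all hypothetical stand-ins for the corresponding properties of the C matrix.

```cpp
// Hypothetical stand-ins for the properties of the C matrix.
enum class Layout { ND, NZ };
enum class ElemType { Half, Float, Int32 };
enum class Location { Global, Local };

struct OutputDesc {
    Layout layout;
    ElemType type;
    Location location;
};

// Returns true only when all three restrictions listed above are met.
constexpr bool CanEnableChannelSplit(const OutputDesc& c)
{
    return c.layout == Layout::NZ &&
           c.type == ElemType::Float &&
           c.location == Location::Global;
}

static_assert(CanEnableChannelSplit({Layout::NZ, ElemType::Float, Location::Global}),
              "NZ + float + global memory satisfies the restrictions");
static_assert(!CanEnableChannelSplit({Layout::NZ, ElemType::Half, Location::Global}),
              "a non-float C matrix does not qualify");
```

A check of this kind could guard the place where isEnableChannelSplit is set, so that an unsupported combination is rejected at compile time.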

Examples

For a complete operator example, see the matmul_channelsplit operator sample.

// Specify the MatmulConfig template to be obtained and modified.
constexpr static MatmulConfigMode configMode = MatmulConfigMode::CONFIG_NORM;
// Set the template parameter isEnableChannelSplit to true to enable the ChannelSplit function of the MatmulConfig template.
constexpr static MatmulFuncParams funcParamsChannelSplit{
    false, false, false, false, 0, IterateOrder::ORDER_M, ScheduleType::INNER_PRODUCT, true, false, false, false, true/*isEnableChannelSplit*/
};
constexpr static MatmulConfig MM_CFG = GetMMConfig<configMode>(funcParamsChannelSplit);
Matmul<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG> mm;

// Perform a regular Matmul computation. The output C matrix is stored in 16 x 8 fractals.
REGIST_MATMUL_OBJ(&pipe, GetSysWorkSpacePtr(), mm);
mm.SetTensorA(gm_a);
mm.SetTensorB(gm_b);
mm.SetBias(gm_bias);
mm.IterateAll(gm_c);