FlashAttention Operator Performance Optimization

Case Study

In this case, the FlashAttentionScoreGrad (FAG) operator computes the backward pass of attention in training scenarios, that is, the backpropagation of the FlashAttentionScore operator.

The forward propagation formula for attention is as follows:

For convenience, the formula can be represented using variables S and P:

Then the attention backpropagation formula is as follows:
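For reference, the forward and backward passes can be written in the following standard form, which is consistent with the computation steps listed below. This is a sketch under stated assumptions rather than the operator's exact definition. scale is commonly 1/√D. Z is the dropout mask already divided by keep_prob, and it is all ones when dropout is disabled. ⊙ denotes element-wise multiplication. The atten_mask input is omitted.

```latex
\begin{aligned}
S  &= \mathrm{scale}\cdot QK^{\top}, \quad P = \mathrm{softmax}(S), \quad Y = (P \odot Z)\,V \\
dV &= (P \odot Z)^{\top}\, dY \\
dP &= (dY\,V^{\top}) \odot Z \\
dS &= P \odot \bigl(dP - \mathrm{rowsum}(dY \odot Y)\bigr) \\
dQ &= \mathrm{scale}\cdot dS\,K \\
dK &= \mathrm{scale}\cdot dS^{\top} Q
\end{aligned}
```

The term rowsum(dY ⊙ Y) corresponds to the FlashSoftmaxGrad result in step 3. It equals rowsum(P ⊙ dP), so the forward output can be reused instead of recomputing that product.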

The following figure shows the computation process.

Figure 1 Operator computation process

The overall computation process of the FlashAttention backpropagation implementation is as follows. It is provided as background for readers interested in the algorithm of this operator.

  1. Recompute p, the softmax result of the FlashAttention forward pass. The result is stored in the UB.

  2. Compute dp. This step consists of a matmul and a dropout computation. In the matmul, the left matrix is dy and the right matrix is the transpose of value.

  3. Compute ds. FlashSoftmaxGrad takes dy and the forward output attention_in as inputs. Its result is subtracted from dp, and the difference is multiplied element-wise by p to obtain ds.

  4. Compute dq. Perform a matmul of the ds result and key, then multiply the result by scale to obtain dq.

  5. Compute dk. Perform a matmul of the transposed ds result and query, then multiply the result by scale to obtain dk.

  6. Compute dv. Apply dropout to p, then perform a matmul of the transposed result and dy.
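The six steps can be checked against a naive CPU reference. The following C++ sketch is hypothetical and written only for illustration; it is not the Ascend C kernel. It handles a single head with S1 = S2 = s, keeps everything in single precision, and omits dropout and atten_mask. The names Mat, Grads, matmul, recomputeP, attentionForward, and attentionBackward are all illustrative.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper type: dense row-major float matrix.
using Mat = std::vector<float>;

// C[m x n] = op(A)[m x k] * op(B)[k x n]; transA/transB read A or B transposed.
Mat matmul(const Mat& a, const Mat& b, size_t m, size_t k, size_t n,
           bool transA = false, bool transB = false) {
    Mat c(m * n, 0.0f);
    for (size_t i = 0; i < m; ++i)
        for (size_t j = 0; j < n; ++j)
            for (size_t p = 0; p < k; ++p)
                c[i * n + j] += (transA ? a[p * m + i] : a[i * k + p]) *
                                (transB ? b[j * k + p] : b[p * n + j]);
    return c;
}

// Step 1: p = softmax(scale * q k^T), computed row by row (max-subtracted).
Mat recomputeP(const Mat& q, const Mat& k, size_t s, size_t d, float scale) {
    Mat p = matmul(q, k, s, d, s, false, true);
    for (size_t i = 0; i < s; ++i) {
        float* row = &p[i * s];
        float mx = -INFINITY;
        for (size_t j = 0; j < s; ++j) mx = std::fmax(mx, scale * row[j]);
        float sum = 0.0f;
        for (size_t j = 0; j < s; ++j) {
            row[j] = std::exp(scale * row[j] - mx);
            sum += row[j];
        }
        for (size_t j = 0; j < s; ++j) row[j] /= sum;
    }
    return p;
}

// Forward output y = p v (the attention_in used by step 3).
Mat attentionForward(const Mat& q, const Mat& k, const Mat& v,
                     size_t s, size_t d, float scale) {
    return matmul(recomputeP(q, k, s, d, scale), v, s, s, d);
}

struct Grads { Mat dq, dk, dv; };

Grads attentionBackward(const Mat& q, const Mat& k, const Mat& v, const Mat& y,
                        const Mat& dy, size_t s, size_t d, float scale) {
    assert(q.size() == s * d && k.size() == s * d && v.size() == s * d);
    Mat p = recomputeP(q, k, s, d, scale);          // step 1
    Mat dp = matmul(dy, v, s, d, s, false, true);   // step 2: dp = dy v^T
    Mat ds(s * s);                                  // step 3
    for (size_t i = 0; i < s; ++i) {
        float rowDot = 0.0f;                        // FlashSoftmaxGrad: rowsum(dy * y)
        for (size_t c = 0; c < d; ++c) rowDot += dy[i * d + c] * y[i * d + c];
        for (size_t j = 0; j < s; ++j)
            ds[i * s + j] = p[i * s + j] * (dp[i * s + j] - rowDot);
    }
    Grads g{matmul(ds, k, s, s, d),                 // step 4: ds k
            matmul(ds, q, s, s, d, true),           // step 5: ds^T q
            matmul(p, dy, s, s, d, true)};          // step 6: p^T dy
    for (size_t i = 0; i < s * d; ++i) { g.dq[i] *= scale; g.dk[i] *= scale; }
    return g;
}
```

On small shapes, you can validate a tiled kernel's output by comparing it with attentionBackward, and you can check attentionBackward itself against finite differences of the loss sum(dy ⊙ y).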

In this case, verification is conducted on Atlas A2 training products / Atlas A2 inference products using two scenarios. In the first (causal) scenario, the input dimensions are B=1, N1=12, N2=12, S1=6144, S2=6144, D=128. "Causal" means that atten_mask has a lower-triangular shape, as shown in Figure 2. In the second scenario, the input dimensions are B=24, N1=5, N2=5, S1=9216, S2=9216, D=64, with no atten_mask or drop_mask input. The optimization methods include tiling size adjustment, inter-core load balancing, CV pipeline parallelism, MTE2 pipeline optimization, and FixPipe pipeline optimization.

Figure 2 atten_mask shape in the causal scenario

Obtaining Profile Data

Two tools are used for pipeline analysis: the profiling tool and the CAModel tool. The profile data generated by the profiling tool shows the ratio of each pipeline. The dotting chart generated by the CAModel tool shows how the pipelines run in parallel.

Analyzing Main Bottlenecks

The performance bottlenecks can be determined by observing and analyzing the pipeline chart and profile data. Different bottlenecks may occur in different phases of the optimization process. Therefore, continuous optimization is required to achieve optimal performance.

  • Experience shows that each loop cycle can carry unnecessary overheads, so more cycles generally mean poorer performance. As long as the UB capacity is not exceeded, larger UB tiles result in fewer cycles. In this operator, the InitBuffer API allocates the UB buffers:
    pipe->InitBuffer(ubBuffer, 120 * 1024);
    pipe->InitBuffer(tmpBuffer, 30 * 1024);
    pipe->InitBuffer(vecClc3, 8 * 1024);
    

    The second parameter of InitBuffer is the buffer size in bytes, and the sum of all buffer sizes is the total UB space occupied. Here, 120 x 1024 + 30 x 1024 + 8 x 1024 bytes = 158 KB. This is less than the UB size, so the UB space is not fully used.

  • In the following figure, green indicates the Vector pipelines and orange indicates the Cube pipelines. Both pipelines have large gaps, and the Cube and Vector (CV) pipelines largely do not run in parallel. Therefore, CV pipeline optimization is needed.
    Figure 3 Operator pipelines before optimization
  • In the causal scenario, the inter-core load may be uneven. With atten_mask applied, only the red blocks in the following figure need to be computed, and the green blocks can be skipped. Suppose cores are assigned by rows of the first axis (8 rows) instead of by the number of basic blocks. The 8 rows are then spread evenly over nine cores, and each core computes ceil(8/9) = 1 row. As a result, the first core computes only one basic block while the eighth core computes eight, causing obvious load imbalance. The red blocks therefore need to be distributed evenly across the cores.
    Figure 4 atten_mask shape in the causal scenario
  • Profile data in scenario 1 is as follows. aic_fixpipe_ratio is extremely high, indicating a possible FixPipe bound.
    Figure 5 Profile data in scenario 1
  • Profile data in scenario 2 is as follows. mte2_ratio is high, indicating a possible MTE2 bound.
    Figure 6 Profile data in scenario 2
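The UB budget check from the first bullet can be expressed directly in code. In this sketch, the three sizes come from the InitBuffer calls above. The UB capacity is deliberately a parameter (ubCapacityBytes, a hypothetical name) because it depends on the target hardware.

```cpp
#include <cassert>
#include <cstdint>

// Sizes passed to InitBuffer above, in bytes.
constexpr uint32_t kUbBufferBytes  = 120 * 1024;
constexpr uint32_t kTmpBufferBytes = 30 * 1024;
constexpr uint32_t kVecClc3Bytes   = 8 * 1024;
constexpr uint32_t kUbUsedBytes = kUbBufferBytes + kTmpBufferBytes + kVecClc3Bytes;

// Remaining UB bytes that could be used to enlarge tiles. The capacity is
// passed in rather than hard-coded: take it from the hardware specification.
uint32_t ubHeadroom(uint32_t ubCapacityBytes) {
    assert(kUbUsedBytes <= ubCapacityBytes && "buffers exceed the UB capacity");
    return ubCapacityBytes - kUbUsedBytes;
}
```

For example, suppose the UB capacity is 192 KB. This is an assumed value, so check the specification of your product. In that case, ubHeadroom reports 34 KB of unused space that a larger basic block could occupy.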

Developing Optimization Solutions

  • Optimization 1: Adjustment of tiling basic blocks

    As long as the UB capacity is not exceeded, larger tiling basic blocks are better. The following figure shows the computation with (64, 128) tiling before optimization, which requires 32 computation cycles in total.

    Figure 7 Computing basic blocks before optimization

    Because the UB space is not fully used, the basic block size is changed to (128, 128). As shown in the following figure, only 16 computation cycles are required after optimization, and this tiling change doubles operator performance.

    Figure 8 Computing basic blocks after optimization
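The cycle counts follow from a round-up division of the region by the basic block. The size of the region tiled in Figure 7 is not stated. The usage below therefore assumes a hypothetical 512 x 512 region, which reproduces the 32 and 16 cycles quoted above. ceilDiv and tilingCycles are illustrative names.

```cpp
#include <cassert>
#include <cstdint>

// Round-up integer division: a partial tail tile still costs one cycle.
uint32_t ceilDiv(uint32_t a, uint32_t b) {
    assert(b > 0);
    return (a + b - 1) / b;
}

// Number of basic-block iterations needed to cover a rows x cols region.
uint32_t tilingCycles(uint32_t rows, uint32_t cols, uint32_t baseM, uint32_t baseN) {
    return ceilDiv(rows, baseM) * ceilDiv(cols, baseN);
}
```

For that region, tilingCycles(512, 512, 64, 128) returns 32 and tilingCycles(512, 512, 128, 128) returns 16. Doubling baseM halves the number of iterations, and with it the per-cycle overhead.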
  • Optimization 2: CV pipeline parallelism

    In the FAG operator, Cube computation is faster than Vector computation and depends on its results. To reduce the number of CV communications, a cache mechanism lets matmul compute multiple blocks in advance: matmul computes several basic blocks in one call and caches the results in GM. In the following code, the singleCoreM and singleCoreN values set by SetTail are multiples of baseM and baseN, respectively. A single matmul call therefore computes multiple basic blocks and caches the results, and the Vector side reads them over several iterations.

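    // mm3: dk = ds^T * query; ds is read from the mul workspace and transposed
    mm3.SetTail(s2CvExtend, -1, preS1Extend);
    mm3.SetTensorA(mulWorkSpaceGm[pingpongIdx * coreNum * cubeBaseMN + cBlockIdx * cubeBaseMN], true);
    mm3.SetTensorB(queryGm[mm2aTensorOffsetCv]);
    // Compute all basic blocks of this call and accumulate (atomic add) into the dk workspace
    mm3.template IterateAll<false>(dkWorkSpaceGm[bTensorOffsetCv], true);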
    
    Figure 9 Pipeline of the mm1, mm2, and mm3 caches

    The preceding figure shows the pipeline chart with the mm1, mm2, and mm3 results cached. Parallelism improves and the CV intervals shrink, which improves operator performance.

    Figure 10 Inserting the next round of Vector computation into the CV intervals

    After the mm1, mm2, and mm3 results are cached, the next round of Vector computation is inserted into the intervals in which the current round of Vector computation waits for the Cube pipelines. This improves the parallelism between the Vector and Cube pipelines and makes the Vector computation denser in the pipeline chart. The following shows the pseudo-code of the original computation process, followed by the pseudo-code in which the next round of Vector computation is inserted into the CV intervals.

    // Pseudo-code of the original computation process
    // mm1 computation;
    dropout();
    Sub();
    // mm2 computation;
    Softmax();
    AttenMask();
    ...

    // Pseudo-code with the next round of Vector computation inserted into the CV interval
    // mm1 computation;
    dropout();
    Sub();
    dropout(); // Next round of Vector computation
    Sub();     // Next round of Vector computation
    // mm2 computation;
    Softmax();
    AttenMask();
    ...
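An idealized two-stage pipeline model shows why overlapping Cube and Vector work helps. This sketch is a generic textbook model, not a model of the FAG kernel. It assumes that every round costs a fixed cubeTime on Cube followed by vecTime on Vector, and it ignores synchronization cost. The function names and parameters are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Without overlap, every round pays Cube time plus Vector time.
uint64_t serialTime(uint64_t rounds, uint64_t cubeTime, uint64_t vecTime) {
    return rounds * (cubeTime + vecTime);
}

// With overlap, Cube works on round i + 1 while Vector processes round i, so
// after the first Cube stage the slower of the two units sets the pace.
uint64_t overlappedTime(uint64_t rounds, uint64_t cubeTime, uint64_t vecTime) {
    assert(rounds > 0);
    return cubeTime + (rounds - 1) * std::max(cubeTime, vecTime) + vecTime;
}
```

With 4 rounds, cubeTime = 2, and vecTime = 3, serialTime returns 20 and overlappedTime returns 14. The difference is the Cube time hidden behind Vector work, which is the kind of benefit that the caching and reordering above aim for.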
    
  • Optimization 3: Load balancing for each core
    Figure 11 Computation amount of each core before optimization in the causal scenario
    Figure 12 Computation amount of each core after optimization in the causal scenario

    The goal is for each core to compute the same amount of data. Figure 11 shows the core assignment before optimization and the workload of each core. Cores are assigned by the first axis (8 rows) across nine cores, so each core computes ceil(8/9) = 1 row: the first core computes only one basic block, while the eighth core computes eight. In Figure 12, the 36 basic blocks marked in red are distributed evenly, so each core computes four basic blocks. The operator finishes only when its most heavily loaded core finishes. Reducing the maximum per-core workload from eight blocks to four therefore doubles performance.
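Counting blocks makes it easy to compare the two core-assignment strategies. The sketch below assumes the 8 x 8 block grid and nine cores from Figures 11 and 12. rowSplit models the row-based assignment. blockSplit deals the valid lower-triangular blocks out in contiguous, equally sized chunks. All names are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Blocks to compute in an n x n causal (lower-triangular) grid: row r has r + 1.
uint32_t causalBlocks(uint32_t n) { return n * (n + 1) / 2; }

// Row-based split: block row r is assigned to core r % cores.
std::vector<uint32_t> rowSplit(uint32_t n, uint32_t cores) {
    assert(cores > 0);
    std::vector<uint32_t> load(cores, 0);
    for (uint32_t r = 0; r < n; ++r) load[r % cores] += r + 1;
    return load;
}

// Block-based split: the valid blocks, in row-major order, are dealt out in
// contiguous chunks of ceil(total / cores) blocks per core.
std::vector<uint32_t> blockSplit(uint32_t n, uint32_t cores) {
    assert(cores > 0);
    const uint32_t total = causalBlocks(n);
    const uint32_t chunk = (total + cores - 1) / cores;
    std::vector<uint32_t> load(cores, 0);
    for (uint32_t c = 0; c < cores; ++c) {
        const uint32_t begin = std::min(total, c * chunk);
        load[c] = std::min(total, begin + chunk) - begin;
    }
    return load;
}

// The most heavily loaded core determines when the operator finishes.
uint32_t maxLoad(const std::vector<uint32_t>& load) {
    return *std::max_element(load.begin(), load.end());
}
```

For 8 block rows and 9 cores, causalBlocks returns 36. The largest rowSplit load is 8 blocks, and the largest blockSplit load is 4 blocks, which matches the twofold improvement described above.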

  • Optimization 4: FixPipe optimization

    According to the collected profile data, the Cube FixPipe ratio is 81%, indicating a severe FixPipe bound. The CAModel output shows many abnormal 128-byte transfers, and a review of the code shows that the workspace addresses are not 512-byte aligned.

    Figure 13 Profile data before optimization in scenario 1

    In the code, SetGlobalBuffer sets the start address of each workspace region. These start addresses must be 512-byte aligned; otherwise, transfer efficiency drops. For details, see Using 512-Byte Alignment for the GM Address. In the following code, ADDR_ALIGN_SIZE is 512, and each offset is rounded up to a multiple of it before the next SetGlobalBuffer call.

    // init workspace address
    syncGlobal.SetGlobalBuffer((__gm__ int32_t*)workspace);
    uint64_t workspaceOffsets = SYNC_GLOBAL_WORKSPACE_SIZE;
    dqWorkSpaceGm.SetGlobalBuffer((__gm__ float*)workspace + workspaceOffsets / sizeof(T2));
    workspaceOffsets = (workspaceOffsets + qPostBlockTotal * sizeof(float) + ADDR_ALIGN_SIZE) / ADDR_ALIGN_SIZE * ADDR_ALIGN_SIZE;
    dkWorkSpaceGm.SetGlobalBuffer((__gm__ float*)workspace + workspaceOffsets / sizeof(T2));
    workspaceOffsets = (workspaceOffsets + kvPostBlockTotal * sizeof(float) + ADDR_ALIGN_SIZE) / ADDR_ALIGN_SIZE * ADDR_ALIGN_SIZE;
    dvWorkSpaceGm.SetGlobalBuffer((__gm__ float*)workspace + workspaceOffsets / sizeof(T2));
    workspaceOffsets = (workspaceOffsets + kvPostBlockTotal * sizeof(float) + ADDR_ALIGN_SIZE) / ADDR_ALIGN_SIZE * ADDR_ALIGN_SIZE;
    // matmul1 and matmul2 workspace size
    matmulWorkspaceSize = cubeBaseMN * sizeof(float);
    mm1WorkspaceGm.SetGlobalBuffer((__gm__ T2*)(workspace + workspaceOffsets + cBlockIdx * matmulWorkspaceSize));
    mm2WorkspaceGm.SetGlobalBuffer((__gm__ T2*)(workspace + workspaceOffsets + coreNum * matmulWorkspaceSize + cBlockIdx * matmulWorkspaceSize));
    // drop workspace offset
    workspaceOffsets = (workspaceOffsets + coreNum * cubeBaseMN * sizeof(float) * INPUT_NUMS + ADDR_ALIGN_SIZE) / ADDR_ALIGN_SIZE * ADDR_ALIGN_SIZE;
    dropWorkSpaceGm.SetGlobalBuffer((__gm__ T1*)workspace + workspaceOffsets / sizeof(T1));
    // mul workspace offset
    workspaceOffsets = (workspaceOffsets + coreNum * cubeBaseMN * sizeof(half) * 2 + ADDR_ALIGN_SIZE) / ADDR_ALIGN_SIZE * ADDR_ALIGN_SIZE;
    mulWorkSpaceGm.SetGlobalBuffer((__gm__ T1*)workspace + workspaceOffsets / sizeof(T1));
    

    After the code is modified, the workspace address is 512-byte aligned, so the FixPipe time is halved.

    Figure 14 Profile data after optimization in scenario 1
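The rounding in the workspace code can be moved into helper functions. This is a sketch with hypothetical names. alignUp is the conventional round-up idiom, and alignPastEnd reproduces the expression used in the code above.

```cpp
#include <cassert>
#include <cstdint>

constexpr uint64_t ADDR_ALIGN_SIZE = 512;  // same value as in the workspace code

// Smallest multiple of align that is >= offset.
uint64_t alignUp(uint64_t offset, uint64_t align) {
    assert(align > 0);
    return (offset + align - 1) / align * align;
}

// (end + align) / align * align, as used above: the smallest multiple of
// align that is strictly greater than end.
uint64_t alignPastEnd(uint64_t end, uint64_t align) {
    assert(align > 0);
    return (end + align) / align * align;
}

bool isAligned(uint64_t offset, uint64_t align) { return offset % align == 0; }
```

Both helpers return a 512-byte-aligned offset. However, alignPastEnd(1024, 512) returns 1536 instead of 1024, so the original expression reserves up to ADDR_ALIGN_SIZE extra bytes per region. This padding is harmless as long as the workspace size requested on the host side accounts for it.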
  • Optimization 5: MTE2 optimization

    The following profile data and pipeline chart show that the operator is MTE2-bound and that some MTE2 transfers take abnormally long.

    Figure 15 Profile data in scenario 2
    Figure 16 Pipeline chart in scenario 2

    After the input layout is changed from BSH to BNSD, the data of each head is stored contiguously, so it is transferred without address jumps. Transfer efficiency doubles, and transfer time is halved.
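To illustrate the effect of the layout change, the sketch below counts how many contiguous address runs one tile of a single head occupies. It uses the element-offset formulas of the BSH layout (H = N x D) and the BNSD layout. The function names are hypothetical, and batch index 0 is used for simplicity.

```cpp
#include <cassert>
#include <cstdint>

// Element offset of (b, n, s, d) in a [B, S, H] tensor with H = N * D.
uint64_t offsetBSH(uint64_t b, uint64_t n, uint64_t s, uint64_t d,
                   uint64_t N, uint64_t S, uint64_t D) {
    return (b * S + s) * (N * D) + n * D + d;
}

// Element offset of (b, n, s, d) in a [B, N, S, D] tensor.
uint64_t offsetBNSD(uint64_t b, uint64_t n, uint64_t s, uint64_t d,
                    uint64_t N, uint64_t S, uint64_t D) {
    return ((b * N + n) * S + s) * D + d;
}

using OffsetFn = uint64_t (*)(uint64_t, uint64_t, uint64_t, uint64_t,
                              uint64_t, uint64_t, uint64_t);

// Number of contiguous address runs covering rows [s0, s0 + rows) of head n
// (all D columns, batch 0). A new run starts whenever a row does not begin
// immediately after the previous row ends.
uint64_t contiguousRuns(OffsetFn offset, uint64_t n, uint64_t s0, uint64_t rows,
                        uint64_t N, uint64_t S, uint64_t D) {
    assert(rows > 0 && D > 0 && s0 + rows <= S);
    uint64_t runs = 1;
    for (uint64_t s = s0 + 1; s < s0 + rows; ++s) {
        uint64_t prevEnd = offset(0, n, s - 1, D - 1, N, S, D) + 1;
        if (offset(0, n, s, 0, N, S, D) != prevEnd) ++runs;
    }
    return runs;
}
```

Take the scenario 2 shape (N = 5, S = 9216, D = 64) and, for example, a 128-row tile of one head. That tile needs 128 separate runs in BSH but a single run in BNSD. Assuming a 2-byte data type such as fp16, each BSH run is only 64 x 2 = 128 bytes long, which illustrates why each BSH transfer is short and strided.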

Verifying Optimization Benefits

  • Adjustment of tiling basic blocks: In theory, larger Vector tiling blocks require fewer computation and transfer cycles and make fuller use of the transfer bandwidth and Vector computing power. After the basic block size is increased from (64, 128) to (128, 128), performance doubles, consistent with the theoretical analysis.
  • CV pipeline parallelism: The expected gain is the Cube and Vector time that becomes hidden once the two pipelines overlap. The measured improvement meets this expectation.
  • Inter-core load balancing: The expected improvement equals the factor by which the workload of the most heavily loaded core is reduced. In this example, that core computes eight blocks before optimization and four after, so performance doubles, as expected.
  • FixPipe optimization: According to the profile data, the FixPipe ratio is 0.8 before optimization and 0.55 after optimization. The test result shows that the operator performance is improved by 45%, which is consistent with the theoretical analysis.
  • MTE2 optimization: According to the profile data, the MTE2 ratio is 0.52 before optimization and reduced by half after optimization. The test result shows that the operator performance is improved by 30%, which is consistent with the theoretical analysis.

Summary

These optimizations can serve as a reference for other fused operator scenarios.