Affinity Operator Tuning Strategy

Symptom

When a BERT model is trained on B4, performance is markedly lower than on A100. For the same step, A100 takes 45 seconds and B4 takes 1,000 seconds, so the NPU delivers about 0.045 times the GPU performance.

Analysis

  1. In the overlap analysis table, computing accounts for most of the time in the step, while the proportion of free time is very low. Therefore, tuning starts with computing.
    Figure 1 Overlap analysis table
    Figure 2 Viewing the computing operators
  2. Check the time consumption of each operator. The computing time of the IndexPutV2 operator accounts for 75% of the total time consumption, so this operator needs to be optimized. For details, see Table 2.
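
The ranking step above can be sketched in a few lines of plain Python. The per-operator durations below are hypothetical values chosen only so that IndexPutV2 lands at the 75% share quoted in the text; they are not measured data, and the operator names other than IndexPutV2 are placeholders.

```python
# Hypothetical per-operator durations (ms) for one training step.
# These numbers are illustrative, not taken from a real profile.
op_time_ms = {
    "IndexPutV2": 750.0,
    "MatMul": 150.0,
    "LayerNorm": 60.0,
    "Add": 40.0,
}


def time_shares(durations):
    """Return (operator, share of total time) pairs, largest share first."""
    total = sum(durations.values())
    shares = {name: t / total for name, t in durations.items()}
    return sorted(shares.items(), key=lambda kv: kv[1], reverse=True)


ranked = time_shares(op_time_ms)
top_op, top_share = ranked[0]

for name, share in ranked:
    print(f"{name:<12} {share:6.1%}")
```

The operator at the top of the ranking is the first candidate for replacement or rewriting.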

    IndexPutV2 is used as the example here. Other operators are handled in a similar way. For details, see Table 1 and Table 2.

    Table 1 Code optimization for the Linear, Reduce_Sum, BatchMatMulV2, RepeatInterleave, and Gatherelement operators

    Columns: Operator, Before Code Optimization, After Code Optimization, Description

    Linear

    Before:
        model = nn.Linear(K, N)
        input = torch.randn(M, K)
        res = model(input)

    After:
        # input is (M, K); weight is (N, K), so weight.t() is (K, N).
        res = torch.addmm(model.bias, input, model.weight.t())

    Description: This optimization applies regardless of scenario.
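
A minimal sketch, runnable on CPU, that checks the two Linear forms above give the same result. The sizes M, K, and N are arbitrary illustrative values.

```python
import torch
from torch import nn

torch.manual_seed(0)
M, K, N = 4, 8, 3  # illustrative sizes only

model = nn.Linear(K, N)
x = torch.randn(M, K)

# Original form: the module call.
res_linear = model(x)

# Rewritten form: bias + x @ weight.t(), with weight stored as (N, K).
res_addmm = torch.addmm(model.bias, x, model.weight.t())

linear_matches = torch.allclose(res_linear, res_addmm, atol=1e-6)
print(linear_matches)
```

This only confirms numerical equivalence. Whether addmm is faster depends on the device and should be confirmed with a profile.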

    Reduce_Sum

    Before:
        mask = torch.nn.functional.one_hot(indices, num_classes=self.num_experts).sum(dim=1)

    After:
        temp_mask = torch.zeros(indices.shape[0], self.num_experts, device="npu", dtype=torch.bfloat16)
        mask = temp_mask.scatter_(-1, indices, 1.0)

    Description: The mask is created with zeros + scatter instead of one_hot + reduce_sum. The local computing time is reduced from 2.2 ms to 0.08 ms, and the total time is reduced by 306 ms.
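
The sketch below compares the two mask constructions on CPU. It assumes float32 on CPU instead of bfloat16 on the NPU, and the index values are made up for illustration. It also shows a caveat: the two forms agree only when the indices within each row are distinct, because one_hot + sum counts duplicates while scatter_ writes 1.0 once.

```python
import torch
import torch.nn.functional as F

num_experts = 8  # illustrative value

# Each row holds the experts selected for one token; values are distinct per row.
indices = torch.tensor([[0, 3], [5, 1], [7, 2]])

# Original form: one-hot encode, then sum over the selection dimension.
mask_onehot = F.one_hot(indices, num_classes=num_experts).sum(dim=1)

# Rewritten form: start from zeros and scatter 1.0 at the selected positions.
temp_mask = torch.zeros(indices.shape[0], num_experts, dtype=torch.float32)
mask_scatter = temp_mask.scatter_(-1, indices, 1.0)

masks_match = torch.equal(mask_onehot, mask_scatter.to(mask_onehot.dtype))

# Caveat: with a repeated index, one_hot + sum yields 2 while scatter_ yields 1.
dup = torch.tensor([[1, 1]])
dup_onehot = F.one_hot(dup, num_classes=num_experts).sum(dim=1)
dup_scatter = torch.zeros(1, num_experts).scatter_(-1, dup, 1.0)
dup_match = torch.equal(dup_onehot, dup_scatter.to(dup_onehot.dtype))

print(masks_match, dup_match)
```

Note also that the one_hot form returns an integer tensor while the scatter form keeps the dtype of the zeros tensor, so downstream code may need a cast.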

    BatchMatMulV2

    Before:
        output = tf.matmul(a, b, transpose_a=True)  # a: [bs, n, 1], b: [n, 1]

    After:
        a_ = tf.transpose(a, perm=[0, 2, 1])    # a_: [bs, 1, n]
        a_ = tf.reshape(a_, [-1, a_.shape[2]])  # a_: [bs, n]
        output = tf.matmul(a_, b)               # a_: [bs, n], b: [n, 1]; a_ is already transposed

    Description: When tf.matmul processes inputs with shapes [bs, n, 1] and [n, 1] to produce a [bs, 1, 1] output, it often triggers an inefficient BatchMatMul path. To improve performance in both the forward and backward passes, reshape or transpose these inputs so that the computation becomes a 2D MatMul or dot product. The 2D result has shape [bs, 1].
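
The shape reasoning in the BatchMatMulV2 entry can be checked with a small sketch. It is written in PyTorch rather than TensorFlow, purely to keep the examples in this section in one framework, so torch.matmul stands in for tf.matmul here. The sizes are illustrative.

```python
import torch

torch.manual_seed(0)
bs, n = 4, 6  # illustrative sizes only

a = torch.randn(bs, n, 1)
b = torch.randn(n, 1)

# Batched form: [bs, 1, n] @ [n, 1] -> [bs, 1, 1]
out_batched = torch.matmul(a.transpose(1, 2), b)

# 2D form: flatten the batch of row vectors into one [bs, n] matrix first.
a_2d = a.transpose(1, 2).reshape(-1, n)  # [bs, n]
out_2d = torch.matmul(a_2d, b)           # [bs, 1]

matmul_matches = torch.allclose(out_batched.reshape(bs, 1), out_2d, atol=1e-6)
print(out_batched.shape, out_2d.shape, matmul_matches)
```

If later code expects the original [bs, 1, 1] shape, the 2D result can be reshaped back at negligible cost.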

    RepeatInterleave

    Before:
        valid_lens = torch.repeat_interleave(valid_lens, shape[1])

    After:
        valid_lens = valid_lens.unsqueeze(-1).expand(-1, shape[1]).reshape(-1)

    Description: Change the shape so that the operator switches to its high-performance branch.

    For the second dimension input, replace shape[1] with a 1D tensor of length 2048 in the form torch.tensor([shape[1], shape[1], shape[1], ..., shape[1]]).
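
A short CPU sketch confirming that, for a 1D input, the unsqueeze/expand/reshape form yields the same values as repeat_interleave. The tensor contents and the repeat count are made-up illustrative values, and repeats stands in for shape[1].

```python
import torch

valid_lens = torch.tensor([3, 1, 4])  # illustrative 1D input
repeats = 5                           # stands in for shape[1]

# Original form: repeat each element `repeats` times in place.
out_repeat = torch.repeat_interleave(valid_lens, repeats)

# Rewritten form: add a trailing axis, broadcast it, then flatten.
out_expand = valid_lens.unsqueeze(-1).expand(-1, repeats).reshape(-1)

repeat_matches = torch.equal(out_repeat, out_expand)
print(out_expand.tolist(), repeat_matches)
```

The equivalence relies on valid_lens being 1D. For higher-rank inputs the two expressions flatten in different ways and should be checked separately.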

    Gatherelement

    Before:
        pt = logit.gather(1, target).view(-1) + eps
        logpt = torch.log(pt)
        alpha = self.alpha.to(logpt.device)
        alpha_class = alpha.gather(0, target.view(-1))

    After:
        pt = logit[torch.arange(logit.size(0)), target.squeeze(1)] + eps
        logpt = torch.log(pt)
        alpha = self.alpha.to(logpt.device)
        alpha_class = torch.index_select(alpha, 0, target.view(-1))

    Description: Performance deteriorates when the Gatherelement operator is called on the NPU. To avoid this, replace torch.gather with torch.index_select, which calls the GatherV2 operator. Note that the index must be modified to match.
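
The following CPU sketch checks that the rewritten indexing in the Gatherelement entry selects the same elements as torch.gather. The tensor sizes, the random inputs, and the eps value are illustrative assumptions. target is taken to have shape [num_samples, 1], as the squeeze(1) call in the entry implies.

```python
import torch

torch.manual_seed(0)
num_samples, num_classes = 5, 4  # illustrative sizes only
eps = 1e-10                      # assumed small constant

logit = torch.rand(num_samples, num_classes)
target = torch.randint(0, num_classes, (num_samples, 1))
alpha = torch.rand(num_classes)

# Original form: gather along the class dimension.
pt_gather = logit.gather(1, target).view(-1) + eps
alpha_gather = alpha.gather(0, target.view(-1))

# Rewritten form: row/column integer indexing and index_select with a 1D index.
pt_index = logit[torch.arange(logit.size(0)), target.squeeze(1)] + eps
alpha_select = torch.index_select(alpha, 0, target.view(-1))

pt_matches = torch.equal(pt_gather, pt_index)
alpha_matches = torch.equal(alpha_gather, alpha_select)
print(pt_matches, alpha_matches)
```

The index change is the part to watch: gather takes the [num_samples, 1] index directly, while index_select requires a 1D index, hence target.view(-1).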

    Table 2 Other operators

    Columns: Operator, Official Link

    IndexPutV2: https://www.hiascend.com/document/detail/zh/Pytorch/60RC3/ptmoddevg/trainingmigrguide/performance_tuning_0033.html

    MatMul/hcom_allReduce: https://www.hiascend.com/document/detail/zh/Pytorch/60RC3/ptmoddevg/trainingmigrguide/performance_tuning_0026.html

    Nonzero: https://www.hiascend.com/document/detail/zh/Pytorch/60RC3/ptmoddevg/trainingmigrguide/performance_tuning_0034.html

    where: https://www.hiascend.com/document/detail/zh/Pytorch/60RC3/ptmoddevg/trainingmigrguide/performance_tuning_0035.html

    RotaryMul & RotaryMulGrad (fused operator): https://www.hiascend.com/document/detail/zh/Pytorch/600/ptmoddevg/trainingmigrguide/performance_tuning_0023.html

    RmsNorm & RmsNormGrad: https://www.hiascend.com/document/detail/zh/Pytorch/600/ptmoddevg/trainingmigrguide/performance_tuning_0024.html

    ScaledMaskedSoftmax & ScaledMaskedSoftmaxGrad: https://www.hiascend.com/document/detail/zh/Pytorch/600/ptmoddevg/trainingmigrguide/performance_tuning_0025.html

    MatmulAllReduce: https://www.hiascend.com/document/detail/zh/Pytorch/600/ptmoddevg/trainingmigrguide/performance_tuning_0026.html

    FlashAttentionScore: https://www.hiascend.com/document/detail/zh/Pytorch/600/ptmoddevg/trainingmigrguide/performance_tuning_0027.html

    SwiGlu: https://www.hiascend.com/document/detail/zh/Pytorch/600/ptmoddevg/trainingmigrguide/performance_tuning_0100.html

    Fusion optimizer: https://www.hiascend.com/document/detail/zh/Pytorch/600/ptmoddevg/trainingmigrguide/performance_tuning_0028.html

    Official document for fused operator replacement: https://www.hiascend.com/document/detail/zh/Pytorch/60RC3/ptmoddevg/trainingmigrguide/performance_tuning_0023.html

    Official document for affinity operator replacement: https://www.hiascend.com/document/detail/zh/Pytorch/60RC3/ptmoddevg/trainingmigrguide/performance_tuning_0033.html

    Official document for affinity API replacement: https://www.hiascend.com/document/detail/zh/Pytorch/60RC3/ptmoddevg/trainingmigrguide/performance_tuning_0036.html