Affinity Operator Tuning Strategy
Symptom
When the BERT model is trained on B4, performance is much lower than on an A100. The same step takes 45 seconds on the A100 but 1,000 seconds on the B4, so NPU performance is only about 0.045 times that of the GPU (45/1,000).
Analysis
- In the overlap analysis table, computing accounts for most of the step time, while the proportion of idle (free) time is very low. Tuning therefore starts with computing.
Figure 1 Overlap analysis table
Figure 2 Viewing the computing operators
- Check the time consumption of each computing operator. The IndexPutV2 operator accounts for 75% of the total time consumption, so it is the first operator to optimize. For details, see Table 2.
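Ranking operators by their share of computing time can be sketched as a small aggregation over per-operator timing records. This is a minimal sketch that assumes the durations have already been exported from the profiler as (operator type, duration) pairs; the function name operator_time_share and the sample records are hypothetical, made up for illustration, and not taken from the BERT profile.

```python
from collections import defaultdict


def operator_time_share(records):
    """Sum durations per operator type; return (op, total, share) sorted by total, descending.

    `records` is an iterable of (op_type, duration) pairs (hypothetical input format).
    """
    totals = defaultdict(float)
    for op_type, duration in records:
        totals[op_type] += duration
    grand_total = sum(totals.values())
    ranked = sorted(totals.items(), key=lambda kv: kv[1], reverse=True)
    if grand_total == 0:
        return [(op, t, 0.0) for op, t in ranked]
    return [(op, t, t / grand_total) for op, t in ranked]


# Made-up records for illustration only; real durations come from the profiler output.
records = [("IndexPutV2", 600.0), ("MatMul", 120.0), ("IndexPutV2", 150.0), ("Add", 130.0)]
for op, total, share in operator_time_share(records):
    print(f"{op:12s} {total:8.1f} us {share:6.1%}")
```

The operator at the top of the list is the first tuning candidate, which is the reasoning applied to IndexPutV2 above.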
Table 1 Code optimization for the Linear, Reduce_Sum, BatchMatMulV2, RepeatInterleave, and Gatherelement operators
Columns: Operator | Before Code Optimization | After Code Optimization | Description
Linear
Before code optimization:
    model = nn.Linear(K, N)
    input = torch.randn(M, K)
    res = model(input)
After code optimization:
    weight, bias = model.weight, model.bias
    # input is (M, K); weight is (N, K), so weight.t() is (K, N).
    res = torch.addmm(bias, input, weight.t())
Description: This tuning solution applies to all scenarios.
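Before measuring the rewrite on the NPU, it helps to confirm that it is numerically equivalent to the module call. The sketch below runs on the CPU with small illustrative sizes for M, K, and N (assumptions, not values from the model) and renames input to x to avoid shadowing the Python built-in; it checks equivalence only and says nothing about NPU speed.

```python
import torch
import torch.nn as nn

M, K, N = 4, 8, 3  # illustrative sizes
torch.manual_seed(0)

model = nn.Linear(K, N)
x = torch.randn(M, K)

# Baseline: the nn.Linear module call (x @ weight.T + bias).
res_linear = model(x)

# Rewritten form: one addmm call, bias + x @ weight.t().
# model.weight has shape (N, K), so weight.t() has shape (K, N); bias (N,) broadcasts to (M, N).
res_addmm = torch.addmm(model.bias, x, model.weight.t())

print(res_addmm.shape, torch.allclose(res_linear, res_addmm, atol=1e-6))
```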
Reduce_Sum
Before code optimization:
    mask = torch.nn.functional.one_hot(indices, num_classes=self.num_experts).sum(dim=1)
After code optimization:
    temp_mask = torch.zeros(indices.shape[0], self.num_experts, device="npu", dtype=torch.bfloat16)
    mask = temp_mask.scatter_(-1, indices, 1.0)
Description: The mask is now built with zeros + scatter instead of one_hot + reduce_sum. The local computing time drops from 2.2 ms to 0.08 ms, and the total time drops by 306 ms.
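The sketch below compares the two mask constructions on the CPU (instead of device="npu") with hypothetical top-2 routing indices. It also shows the caveat behind the rewrite: the masks match only when no expert index repeats within a row, because one_hot + sum counts repeats while scatter_ writes 1.0 once. The result dtype also changes from int64 to bfloat16.

```python
import torch
import torch.nn.functional as F

num_experts = 8
# Hypothetical routing indices: shape [tokens, k], no repeated expert within a row.
indices = torch.tensor([[0, 3], [5, 1], [7, 2]])

# Original: one_hot gives [tokens, k, num_experts]; summing over k gives [tokens, num_experts].
mask_onehot = F.one_hot(indices, num_classes=num_experts).sum(dim=1)

# Rewritten: write 1.0 at the selected expert positions of a zero tensor.
temp_mask = torch.zeros(indices.shape[0], num_experts, dtype=torch.bfloat16)
mask_scatter = temp_mask.scatter_(-1, indices, 1.0)

print(torch.equal(mask_onehot, mask_scatter.to(mask_onehot.dtype)))

# Caveat: with a repeated index, the two forms diverge.
dup = torch.tensor([[4, 4]])
print(F.one_hot(dup, num_experts).sum(dim=1)[0, 4].item())        # -> 2
print(torch.zeros(1, num_experts).scatter_(-1, dup, 1.0)[0, 4].item())  # -> 1.0
```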
BatchMatMulV2
Before code optimization:
    output = tf.matmul(a, b, transpose_a=True)  # a: [bs, n, 1], b: [n, 1] -> output: [bs, 1, 1]
After code optimization:
    a_ = tf.transpose(a, perm=[0, 2, 1])     # [bs, 1, n]
    a_ = tf.reshape(a_, [-1, a_.shape[2]])   # [bs, n]
    output = tf.matmul(a_, b)                # [bs, n] x [n, 1] -> [bs, 1]
    output = tf.reshape(output, [-1, 1, 1])  # restore [bs, 1, 1] if downstream code expects it
Description: When tf.matmul multiplies inputs of shapes [bs, n, 1] and [n, 1] to produce a [bs, 1, 1] output, it often takes an inefficient BatchMatMul path. To improve performance in both the forward and backward passes, reshape or transpose the inputs so that the computation becomes a 2D MatMul or a dot product.
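To check that the 2D formulation gives the same values as the batched one, the sketch below uses NumPy as a CPU stand-in for TensorFlow: np.matmul follows the same batch-broadcasting rule for [bs, 1, n] x [n, 1]. The sizes bs and n are illustrative assumptions, and the sketch says nothing about which NPU kernel runs. Note that the 2D matmul takes a_ without transpose_a: [bs, n] x [n, 1] is already the right orientation.

```python
import numpy as np

bs, n = 4, 5  # illustrative sizes
rng = np.random.default_rng(0)
a = rng.standard_normal((bs, n, 1)).astype(np.float32)
b = rng.standard_normal((n, 1)).astype(np.float32)

# Original pattern: batched a^T ([bs, 1, n]) times b ([n, 1]) -> [bs, 1, 1].
out_batched = np.matmul(np.transpose(a, (0, 2, 1)), b)

# Rewritten pattern: collapse a to 2D [bs, n], then one plain 2D matmul -> [bs, 1].
a_ = np.transpose(a, (0, 2, 1)).reshape(-1, n)
out_2d = np.matmul(a_, b)

# Restore [bs, 1, 1] for code that expects the original output shape.
out_2d = out_2d.reshape(bs, 1, 1)

print(out_batched.shape, out_2d.shape, np.allclose(out_batched, out_2d, atol=1e-5))
```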
RepeatInterleave
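Before code optimization:
    valid_lens = torch.repeat_interleave(valid_lens, shape[1])
After code optimization:
    valid_lens = valid_lens.unsqueeze(-1).expand(-1, shape[1]).reshape(-1)
Description: Changing the shape of the computation switches the operator to a high-performance branch.
For the second input (repeats), shape[1] can also be replaced with a 1D tensor of length 2048, that is, torch.tensor([shape[1], shape[1], shape[1], ..., shape[1]]). torch.repeat_interleave requires a multi-element repeats tensor to have one entry per element of valid_lens.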
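The three forms in this row produce the same result for a 1D valid_lens, which the CPU sketch below checks. The contents of valid_lens and the value 2 standing in for shape[1] are hypothetical; the tensor-repeats variant uses one entry per element of valid_lens, as repeat_interleave requires.

```python
import torch

valid_lens = torch.tensor([3, 1, 4])  # hypothetical 1D valid lengths
repeats = 2                           # stands in for shape[1]

# Original: scalar repeats.
out_scalar = torch.repeat_interleave(valid_lens, repeats)

# Rewritten: [L] -> [L, 1] -> expand (a view) to [L, repeats] -> flatten (copies the data).
out_expand = valid_lens.unsqueeze(-1).expand(-1, repeats).reshape(-1)

# Tensor-repeats variant: one repeat count per element of valid_lens.
repeats_tensor = torch.full((valid_lens.numel(),), repeats, dtype=torch.long)
out_tensor = torch.repeat_interleave(valid_lens, repeats_tensor)

print(out_scalar.tolist())  # [3, 3, 1, 1, 4, 4]
```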
Gatherelement
Before code optimization:
    pt = logit.gather(1, target).view(-1) + eps
    logpt = torch.log(pt)
    alpha = self.alpha.to(logpt.device)
    alpha_class = alpha.gather(0, target.view(-1))
After code optimization:
    pt = logit[torch.arange(logit.size(0), device=logit.device), target.squeeze(1)] + eps
    logpt = torch.log(pt)
    alpha = self.alpha.to(logpt.device)
    alpha_class = torch.index_select(alpha, 0, target.view(-1))
Description: Performance deteriorates when the Gatherelement operator is called on the NPU. Use torch.index_select instead of torch.gather so that the GatherV2 operator is called, which avoids the issue. Note that the indices must be adjusted accordingly: target has shape [N, 1], so it is squeezed or flattened to 1D.
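The rewrite should change only which operator is called, not the values. The CPU sketch below compares both versions with hypothetical logit, target, and alpha tensors (alpha_param stands in for self.alpha, and eps is an assumed small constant). Which NPU operator each call maps to comes from the description above and is not checked here.

```python
import torch

torch.manual_seed(0)
num_samples, num_classes, eps = 4, 3, 1e-6
logit = torch.softmax(torch.randn(num_samples, num_classes), dim=1)
target = torch.tensor([[2], [0], [1], [2]])    # [N, 1] class indices
alpha_param = torch.tensor([0.25, 0.5, 0.25])  # stands in for self.alpha

# Original: torch.gather on both tensors.
pt_gather = logit.gather(1, target).view(-1) + eps
alpha_gather = alpha_param.gather(0, target.view(-1))

# Rewritten: advanced indexing for pt, index_select for alpha.
rows = torch.arange(logit.size(0), device=logit.device)
pt_index = logit[rows, target.squeeze(1)] + eps
alpha_index = torch.index_select(alpha_param, 0, target.view(-1))

print(torch.allclose(pt_gather, pt_index), torch.equal(alpha_gather, alpha_index))
```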