Triangle倒三角

概述

LLM是基于only-decoder的模式,在进行self-attention的计算时会采用sequence-mask的方法,目的是为了使得decoder不能看见未来的信息。也就是对于一个序列中的第i个token,解码的时候只能够依靠i时刻之前(包括i)的的输出,而不能依赖于i时刻之后的输出。因此通常采取一个遮盖的方法(Mask)使得其在计算self-attention的时候只用i个时刻之前的token进行计算。

self-attention的计算会将mask的部分也计算进去,针对这一场景,提出了倒三角attention加速算法,避免attention mask中上三角位置中无需计算的block参与计算,从而减少计算量。计算过程如图所示:

功能使用

倒三角功能的定义位于“ascendspeed/model/triangle_attention.py”文件中的“”TriangleAttention与“”函数。

开关定义参见“ascendspeed/arguments.py”文件中的“_add_distributed_args”函数。

通过配置--triangle-attn来使能。

训练脚本参数配置示例,加粗部分为倒三角功能配置参数示例:

python -m torch.distributed.launch $DISTRIBUTED_ARGS \
   pretrain_llama.py \
   --DDP-impl local \
   --num-layers 30 \
   --hidden-size 6656 \
   --ffn-hidden-size 17920 \
   --num-attention-heads 52 \
   --micro-batch-size 1 \
   --global-batch-size 32 \
   --seq-length 2048 \
   --triangle-attn

功能初始化示例:

def __init__(self, block_size=1024, masked_softmax_func=None,dropout_func=None): 

参数说明:

代码示例:

class LlamaParallelAttention(MegatronModule):
    """Parallel self-attention layer abstract class.
    Self-attention layer takes input with size [b, s, h]
    and returns output of the same size.
    """
    """
    省略
    """
    def __init__(self, init_method,
                 output_layer_init_method, layer_number,
                 attention_type=AttnType.self_attn,
                 attn_mask_type=AttnMaskType.causal,
                 checkpoint_activations=False,
                 checkpoint_policy=None):
        self.use_triangle_attn = args.triangle_attention
        if self.use_triangle_attn:
            self.triangle_attn = Triangle_attention(block_size=1024,
                                                    masked_softmax_func=self.scale_mask_softmax,
                                                    dropout_func=None) 

在attention的前向函数中的位置编码后插入以下代码,q k v layer的shape为bachsize, head_num,sequence,hidden dim,输出的维度为bachsize,sequence,head_num * hidden dim。

if self.use_triangle_attn and layer_past is None:
    context_layer = self.triangle_attn(query_layer, key_layer, value_layer, attention_mask)
    output, _ = self.dense(context_layer)
    return output