mask类型
 mask类型不是独立特性,因paged attention算子的mask较为复杂,为便于理解,此处单独针对mask类型进行说明。
maskType  | 
硬件类型  | 
维度  | 
备注  | 
|---|---|---|---|
UNDEFINED  | 
不传mask  | 
不传mask  | 
相当于一个全零的mask。  | 
MASK_TYPE_NORM  | 
[batch, 1, max_seq_len] 或 [1, max_seq_len] 或 [max_seq_len, max_seq_len]  | 
倒三角mask。  | 
|
[batch, max_seq_len / 16, 16, 16] 或 [1, max_seq_len / 16, 16, 16]  | 
|||
MASK_TYPE_ALIBI  | 
[batch, num_head, 1, max_seq_len] 或 [num_head, 1, max_seq_len]  | 
alibi mask。  | 
|
[batch * num_head, max_seq_len / 16, 16, 16] 或 [num_head, max_seq_len / 16, 16, 16]  | 
|||
MASK_TYPE_SPEC  | 
[num_tokens, max_seq_len]  | 
并行解码mask。  | 
|
[1, max_seq_len / 16, num_tokens, 16]  | 
|||
MASK_TYPE_MASK_FREE  | 
[1, 8,128,16]  | 
固定为128*128的倒三角mask。  | 
上表中