mask类型不是独立特性,因paged attention算子的mask较为复杂,为便于理解,此处单独针对mask类型进行说明。
maskType |
硬件类型 |
维度 |
备注 |
---|---|---|---|
UNDEFINED |
不传mask |
不传mask |
相当于一个全零的mask。 |
MASK_TYPE_NORM |
[batch, 1, max_seq_len] 或 [1, max_seq_len] 或 [max_seq_len, max_seq_len] |
倒三角mask。 |
|
[batch, max_seq_len / 16, 16, 16] 或 [1, max_seq_len / 16, 16, 16] |
|||
MASK_TYPE_ALIBI |
[batch, num_head, 1, max_seq_len] 或 [num_head, 1, max_seq_len] |
alibi mask。 |
|
[batch * num_head, max_seq_len / 16, 16, 16] 或 [num_head, max_seq_len / 16, 16, 16] |
|||
MASK_TYPE_SPEC |
[num_tokens, max_seq_len] |
并行解码mask。 |
|
[1, max_seq_len / 16, num_tokens, 16] |
上表中