mask类型

mask类型不是独立特性,因paged attention算子的mask较为复杂,为便于理解,此处单独针对mask类型进行说明。

maskType

硬件类型

维度

备注

UNDEFINED

不传mask

不传mask

相当于一个全零的mask。

MASK_TYPE_NORM

Atlas 800I A2 推理产品/Atlas A2 训练系列产品

Atlas A3 推理系列产品/Atlas A3 训练系列产品

[batch, 1, max_seq_len] 或 [1, max_seq_len] 或 [max_seq_len, max_seq_len]

倒三角mask。

Atlas 推理系列产品

[batch, max_seq_len / 16, 16, 16] 或 [1, max_seq_len / 16, 16, 16]

MASK_TYPE_ALIBI

Atlas 800I A2 推理产品/Atlas A2 训练系列产品

Atlas A3 推理系列产品/Atlas A3 训练系列产品

[batch, num_head, 1, max_seq_len] 或 [num_head, 1, max_seq_len]

alibi mask。

Atlas 推理系列产品

[batch * num_head, max_seq_len / 16, 16, 16] 或 [num_head, max_seq_len / 16, 16, 16]

MASK_TYPE_SPEC

Atlas 800I A2 推理产品/Atlas A2 训练系列产品

Atlas A3 推理系列产品/Atlas A3 训练系列产品

[num_tokens, max_seq_len]

并行解码mask。

Atlas 推理系列产品

[1, max_seq_len / 16, num_tokens, 16]

上表中Atlas 推理系列产品上max_seq_len应16对齐,且维度描述中的除法均为ceil div。