PromptFlashAttentionDPaddingFusionPass

Description

When the PromptFlashAttention (PFA) operator is called and the D axis of its QKV inputs does not meet the 16-alignment requirement, this pass inserts a Pad operator before the PFA operator to pad the D axis to a multiple of 16, and inserts a Slice operator after the PFA operator to restore the output to its original shape.
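The pad-then-slice transformation can be sketched numerically. The snippet below is an illustrative NumPy sketch, not the actual fusion-pass implementation: the tensor shape, the `align_up` helper, and the placeholder attention step are all assumptions made for demonstration.

```python
import numpy as np

def align_up(d, multiple=16):
    # Round d up to the nearest multiple of 16 (hypothetical helper).
    return (d + multiple - 1) // multiple * multiple

# Hypothetical QKV tensor whose last (D) axis is not 16-aligned.
q = np.random.rand(2, 4, 100, 40).astype(np.float32)
d = q.shape[-1]              # D = 40, not a multiple of 16
d_pad = align_up(d)          # padded to 48

# Pad step: zero-pad the D axis before the attention computation,
# mirroring the inserted Pad operator.
pad_width = [(0, 0)] * (q.ndim - 1) + [(0, d_pad - d)]
q_padded = np.pad(q, pad_width)

# ... attention would run on the padded tensor here ...
out_padded = q_padded        # placeholder for the attention output

# Slice step: cut the D axis back to its original size,
# mirroring the inserted Slice operator.
out = out_padded[..., :d]
assert out.shape == q.shape
```

Zero padding on the D axis is safe for the dot products inside attention because the extra columns contribute nothing to the scores, and slicing afterwards discards the padded region entirely.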

Restrictions

Only the scenario where the D axis of the PromptFlashAttention inputs is not 16-aligned is supported.

Availability