昇腾社区首页
中文
注册

TopKFusionPass

融合模式

该融合规则将TopK或TopKV2节点,基于平台的不同替换为TopKV2/TopKD/TopKV2D算子,或直接拆解为SegmentSort+MultiMerge的组合。详细场景如下。

场景一:TopK/TopKV2算子将被替换为TopKV2算子

场景二:TopK/TopKV2算子被替换为TopKD/TopKV2D算子

场景三:TopK/TopKV2算子将被拆分为SegmentSort+MultiMerge的组合

使用约束

  • 不支持TopK算子的attr.sorted=false的场景。
  • Atlas 200/300/500 推理产品 满足下列条件时,TopK/TopKV2算子将被拆分为SegmentSort+MultiMerge的组合。
    • 输入k为const tensor。
    • 输入x排序轴大小和整个tensor大小一致。
    • 总排序轴长度大于7936。
    • k值大于4096。
  • Atlas 200/300/500 推理产品 ,且输入k非const tensor时,TopK/TopKV2算子会替换为TopKV2算子。
  • 除以上两种场景外,开启该规则,TopK/TopKV2算子会替换为TopKD/TopKV2D算子。

支持的型号

该Pass的有效性依赖于运行的产品是否支持相应算子类型TopK/TopKV2,具体信息请参考“CANN算子规格说明”章节。