昇腾社区首页
中文
注册

通算融合算子

相比于一般的计算或搬运类算子,通算融合算子是融合集合通信任务和计算任务的算子,在算子执行过程中,计算和通信任务可以部分流水并行,以便提升性能。典型的应用场景如Matmul计算+集合通信。通算融合类算子的实现,请参考并使用Hccl高阶API。关于更多集合通信的内容和相关概念请参考“HCCL概述”章节。

通算融合算子的开发过程与一般算子相同,但请注意,当前通算融合算子暂不支持Kernel直调入图(GE图)开发

算子原型定义

相比于一般算子,通算融合算子在实现算子原型定义时,有如下约束:

  • 必须定义一个表示算子通信域名称的属性。通信域是集合通信执行的上下文,管理对应的通信实体(例如一个NPU就是一个通信实体)和通信所需的资源。
  • 必须通过MC2接口注册该算子为通算融合算子,并通过HcclGroup接口配置该算子的通信域名称。

以算子属性"group"为例,"group"为该算子的通信域名称,在算子原型定义中,调用方式如下:

1
2
3
this->Attr("group").AttrType(REQUIRED).String(); // "group"为通算融合算子的属性,表示通信域名称
...
this->MC2().HcclGroup("group"); // 将"group"配置为该算子的通信域

重执行

为避免执行通信任务的环境中硬件闪断导致发生通信中断,通算融合算子可通过配置编译宏与环境变量,开启重执行能力。通算融合算子开启重执行后,AI CPU在检测到环境异常时,通过下图示意的机制,通知AI Core重新下发通信任务,避免由于硬件闪断造成的通信中断,提升通信稳定性。

当前,该能力的支持情况如下:

Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件 不支持通算融合算子的重执行。

Atlas A3 训练系列产品/Atlas A3 推理系列产品 支持通算融合算子的重执行。

图1 通信任务重执行机制

开启重执行的条件如下:

  • 通算融合算子的输出内存地址和输入内存地址不相同。
  • 算子编译时,配置编译宏AICORE_EXCEPTION_RESTART,如下所示。具体的编译宏配置阶段和方式请参考支持自定义编译选项
    1
    add_ops_compile_options(ALL OPTIONS -DAICORE_EXCEPTION_RESTART)
    
  • 配置HCCL重执行环境变量HCCL_OP_RETRY_ENABLE,开启重执行的检测和上报能力,该环境变量的说明请参考环境变量参考“执行相关 > 集合通信 > HCCL_OP_RETRY_ENABLE。请在算子执行前设置该环境变量,具体配置如下:
    1
    2
    # server内L0和server间L1均需配置为1, 不支持跨超节点,L2配置为0。
    export HCCL_OP_RETRY_ENABLE="L0:1, L1:1, L2:0" 
    

    注意,开启重执行后,若AI Core第一次下发通信任务后通信中断,默认只重执行一次。若需修改重执行次数或重传间隔时间,请参考环境变量参考“执行相关 > 集合通信 > HCCL_OP_RETRY_PARAMS

调用示例

以AllGatherMatmul自定义算子为例,算子原型定义如下。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
namespace ops {
    class AllGatherMatmulCustom : public OpDef {
    public:
        explicit AllGatherMatmulCustom(const char *name) : OpDef(name) {
            this->Input("x1")
                    .ParamType(REQUIRED)
                    .DataType({ge::DT_FLOAT16, ge::DT_BF16})
                    .Format({ge::FORMAT_ND, ge::FORMAT_ND})
                    .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
            this->Input("x2")
                    .ParamType(REQUIRED)
                    .DataType({ge::DT_FLOAT16, ge::DT_BF16})
                    .Format({ge::FORMAT_ND, ge::FORMAT_ND})
                    .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND})
                    .IgnoreContiguous();
            this->Input("bias")
                    .ParamType(OPTIONAL)
                    .DataType({ge::DT_FLOAT16, ge::DT_BF16})
                    .Format({ge::FORMAT_ND, ge::FORMAT_ND})
                    .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});

            this->Output("y")
                    .ParamType(REQUIRED)
                    .DataType({ge::DT_FLOAT16, ge::DT_BF16})
                    .Format({ge::FORMAT_ND, ge::FORMAT_ND})
                    .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
            this->Output("gather_out")
                    .ParamType(REQUIRED)
                    .DataType({ge::DT_FLOAT16, ge::DT_BF16})
                    .Format({ge::FORMAT_ND, ge::FORMAT_ND})
                    .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});

            this->Attr("group").AttrType(REQUIRED).String();
            this->Attr("isTransA").AttrType(OPTIONAL).Bool(false);
            this->Attr("isTransB").AttrType(OPTIONAL).Bool(false);
            this->Attr("gatherIndex").AttrType(OPTIONAL).Int(0);
            this->Attr("commTurn").AttrType(OPTIONAL).Int(0);
            this->Attr("rank_size").AttrType(OPTIONAL).Int(8);
            this->Attr("is_gather_out").AttrType(OPTIONAL).Bool(true);

            this->AICore().SetTiling(AllGatherMatmulCustomTilingFunc);
            this->AICore().AddConfig("ascendxxx"); // ascendxxx请修改为对应的昇腾AI处理器型号。
            this->MC2().HcclGroup("group");
        }
    };
}