MC2 Operators

Unlike pure data computation or data movement operators, merged compute and communication operators (MC2 operators) fuse collective communication tasks with computing tasks and pipeline the two, improving performance during operator execution. A typical application scenario combines Matmul computation with collective communication. MC2 operators are implemented using the high-level APIs of HCCL. For details about collective communication and related concepts, see HCCL Overview.

The development process of MC2 operators is the same as that of common operators. Note that MC2 operators currently support neither the Kernel Launch approach nor graph (GE) mode development.

Operator Prototype Definition

Compared with common operators, the prototype definition of an MC2 operator has the following additional requirements:

  • Define an attribute that holds the operator's communicator name. A communicator is the context in which collective communication takes place; it manages the participating communication entities (for example, NPUs) and the resources required for communication.
  • Register the operator as an MC2 operator through MC2(), and bind the attribute that carries the communicator name through HcclGroup().

For example, if the attribute group carries the operator's communicator name, the prototype definition calls these APIs as follows:

this->Attr("group").AttrType(REQUIRED).String(); // group is the attribute of the MC2 operator, indicating the communicator name.
...
this->MC2().HcclGroup("group"); // Bind the attribute group as the operator's communicator name.

Example

Take the AllGatherMatmul custom operator as an example. The operator prototype is defined as follows:

namespace ops {
    class AllGatherMatmulCustom : public OpDef {
    public:
        explicit AllGatherMatmulCustom(const char *name) : OpDef(name) {
            this->Input("x1")
                    .ParamType(REQUIRED)
                    .DataType({ge::DT_FLOAT16, ge::DT_BF16})
                    .Format({ge::FORMAT_ND, ge::FORMAT_ND})
                    .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
            this->Input("x2")
                    .ParamType(REQUIRED)
                    .DataType({ge::DT_FLOAT16, ge::DT_BF16})
                    .Format({ge::FORMAT_ND, ge::FORMAT_ND})
                    .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND})
                    .IgnoreContiguous();
            this->Input("bias")
                    .ParamType(OPTIONAL)
                    .DataType({ge::DT_FLOAT16, ge::DT_BF16})
                    .Format({ge::FORMAT_ND, ge::FORMAT_ND})
                    .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});

            this->Output("y")
                    .ParamType(REQUIRED)
                    .DataType({ge::DT_FLOAT16, ge::DT_BF16})
                    .Format({ge::FORMAT_ND, ge::FORMAT_ND})
                    .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
            this->Output("gather_out")
                    .ParamType(REQUIRED)
                    .DataType({ge::DT_FLOAT16, ge::DT_BF16})
                    .Format({ge::FORMAT_ND, ge::FORMAT_ND})
                    .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});

            this->Attr("group").AttrType(REQUIRED).String();
            this->Attr("isTransA").AttrType(OPTIONAL).Bool(false);
            this->Attr("isTransB").AttrType(OPTIONAL).Bool(false);
            this->Attr("gatherIndex").AttrType(OPTIONAL).Int(0);
            this->Attr("commTurn").AttrType(OPTIONAL).Int(0);
            this->Attr("rank_size").AttrType(OPTIONAL).Int(8);
            this->Attr("is_gather_out").AttrType(OPTIONAL).Bool(true);

            this->AICore().SetTiling(AllGatherMatmulCustomTilingFunc);
            this->AICore().AddConfig("ascendxxx"); // Replace ascendxxx with the actual Ascend AI Processor model.
            this->MC2().HcclGroup("group");
        }
    };
}
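The example above registers a tiling function, AllGatherMatmulCustomTilingFunc, through SetTiling but does not show its body. A minimal skeleton of such a function is sketched below; the attribute index and the tiling logic are illustrative assumptions, since the real tiling data depends on the operator's kernel implementation:

```cpp
// Illustrative skeleton only: the actual tiling logic depends on the kernel
// implementation and on the HCCL resources prepared at tiling time.
static ge::graphStatus AllGatherMatmulCustomTilingFunc(gert::TilingContext *context)
{
    // Read the "group" attribute that names the communicator.
    // Assumption: attribute index 0 matches the declaration order in the
    // operator prototype definition ("group" is declared first).
    auto attrs = context->GetAttrs();
    if (attrs == nullptr) {
        return ge::GRAPH_FAILED;
    }
    const char *group = attrs->GetAttrPointer<char>(0);
    (void)group; // Used when preparing the communication part of the tiling.

    // ... compute and serialize the operator-specific tiling data here ...

    return ge::GRAPH_SUCCESS;
}
```

The function returns ge::GRAPH_SUCCESS on success so that the framework proceeds with kernel dispatch; any failure in attribute retrieval or tiling computation should return ge::GRAPH_FAILED instead.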