算子使用指导(torch_atb)

  1. 导入torch_atb模块。

    在代码中,通过以下方式导入torch_atb模块:

    import torch_atb

  2. 创建算子对象实例

    • 单算子
      1. 创建参数对象

        根据需要创建的算子,实例化参数结构体,参数接口的定义参考OpParam,下面以linear为例。

        linear_param = torch_atb.LinearParam()
        linear_param.has_bias = False
      2. 创建算子对象
        op = torch_atb.Operation(linear_param)
    • 图算子

      根据设计的图算子结构,通过tensorName设置各个组合OP的tensor之间的对应关系创建图算子。

      用图构造器构图:

      graph = torch_atb.GraphBuilder("Graph") \ # 设置图算子的名称
               .set_input_output(["x", "y", "z"], ["out"]) \ # 指定图的输入输出tensorName
               .add_operation(elewise_add, ["x", "y"], ["add_out"]) \ # 添加图中需要的单算子,及使用的tensor
               .reshape("add_out", lambda shape: [1, shape[0] * shape[1]], "add_out_") \ # 在需要更新输出tensor的shape时,使用lambda表达式动态更新
               .add_operation(elewise_mul, ["add_out_", "z"], ["out"]) \
               .build() # 结束构图、创建图算子

  3. 准备输入数据

    准备输入tensor,确保设备(npu、cpu)、数据类型符合预期。

    x = torch.randn(2, 3, dtype=torch.float16).npu()  
    y = torch.randn(2, 3, dtype=torch.float16).npu()

  4. 执行操作

    使用 forward 方法完成操作,并获取输出:

    outputs = op.forward([x, y]) # 如果使用图算子,则outputs = graph.forward([x, y, z]) 
    torch.npu.synchronize()