在代码中,通过以下方式导入torch_atb模块:
import torch_atb
根据需要创建的算子,实例化参数结构体,参数接口的定义参考OpParam,下面以linear为例。
linear_param = torch_atb.LinearParam() linear_param.has_bias = False
op = torch_atb.Operation(linear_param)
根据设计的图算子结构,通过tensorName设置各个组合OP的tensor之间的对应关系创建图算子。
用图构造器构图:
graph = torch_atb.GraphBuilder("Graph") \ # 设置图算子的名称 .set_input_output(["x", "y", "z"], ["out"]) \ # 指定图的输入输出tensorName .add_operation(elewise_add, ["x", "y"], ["add_out"]) \ # 添加图中需要的单算子,及使用的tensor .reshape("add_out", lambda shape: [1, shape[0] * shape[1]], "add_out_") \ # 在需要更新输出tensor的shape时,使用lambda表达式动态更新 .add_operation(elewise_mul, ["add_out_", "z"], ["out"]) \ .build() # 结束构图、创建图算子
准备输入tensor,确保设备(npu、cpu)、数据类型符合预期。
x = torch.randn(2, 3, dtype=torch.float16).npu() y = torch.randn(2, 3, dtype=torch.float16).npu()
使用 forward 方法完成操作,并获取输出:
outputs = op.forward([x, y]) # 如果使用图算子,则outputs = graph.forward([x, y, z]) torch.npu.synchronize()