昇腾社区首页
中文
注册

GraphBuilder

创建GraphBuilder

  • 功能

    创建GraphBuilder,用于后续的构图操作。

  • 原型
    torch_atb.GraphBuilder (graphName: string) -> torch_atb.GraphBuilder
  • 参数

    图op的名称,类型为String。

  • 返回值

    GraphBuilder

  • 注意

    调用接口失败,会抛出异常终止运行。

  • 使用示例
    1
    builder = torch_atb.Builder("Graph")
    

设置计算图的输入

  • 功能

    设置计算图的输入。

  • 原型
    GraphBuilder.AddInput (name: string) -> torch_atb.GraphBuilder
  • 参数

    输入的名称,类型为String。

  • 返回值

    string

  • 注意

    调用接口失败,会抛出异常终止运行。

  • 使用示例
    1
    2
    x = builder.add_input("x")
    y = builder.add_input("y")
    

设置计算图的节点

  • 功能

    设置计算图的节点。

  • 原型
    GraphBuilder.AddNode (inputs: vector<string>, param: xxxParam) -> torch_atb.GraphBuilder
  • 参数

    输入的名称,使用torch_atb创建的算子的参数。

  • 返回值

    GraphNode

  • 注意

    调用接口失败,会抛出异常终止运行。

  • 使用示例
    1
    2
    3
    4
    # 创建并配置算子参数
    elewise_add = torch_atb.ElewiseParam()
    elewise_add.elewise_type = torch_atb.ElewiseParam.ElewiseType.ELEWISE_ADD
    layer1 = builder.add_node([x, y], elewise_add)
    

设置计算图的输出

  • 功能

    标记传入tensor为计算图输出。

  • 原型
    GraphBuilder.MarkOutput (outTensor: string) -> torch_atb.GraphBuilder
  • 参数

    tensor的名称。

  • 注意

    调用接口失败,会抛出异常终止运行。

  • 使用示例
    1
    builder.mark_output(layer2.get_output(0))
    

张量转置

  • 功能

    改变tensor的shape。

  • 原型
    GraphBuilder.Reshape (string, std::function, string) -> torch_atb.GraphBuilder
  • 参数

    tensor的名称。

  • 注意

    调用接口失败,会抛出异常终止运行。

  • 使用示例
    1
    2
    3
    4
    5
    layer1 = builder.add_node([x, y], elewise_add)
    add_out = layer1.get_output(0)
    builder.reshape(add_out, lambda shape: [1, shape[0] * shape[1]], "add_out_")
    # 转置后使用"add_out_"将tensor与其他node绑定
    layer2 = builder.add_node(["add_out_", z], elewise_mul)
    

构建计算图

  • 功能

    创建图op。

  • 原型
    GraphBuilder.Build () -> torch_atb.GraphBuilder
  • 返回值
    atb::OperationWrapper
  • 注意

    调用接口失败,会抛出异常终止运行。

  • 使用示例
    1
    2
    3
    builder = torch_atb.Builder("Graph")
    # 省略添加输入,节点,标记输出过程
    Graph = builder.build()
    

设置实际执行流

  • 功能

    设置实际执行流。

  • 原型
    GraphBuilder.SetExecuteStreams (executeStreams: std::vector<std::uintptr_t>) -> torch_atb.GraphBuilder
  • 参数

    实际执行流(aclStream)列表

  • 注意

    调用接口失败,会抛出异常终止运行。

  • 使用示例
     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    11
    def create_streams():
        streams = []
        for i in range(2):
            stream, ret = acl.rt.create_stream()
            if ret != ACL_SUCCESS:
                exit(0)
            streams.append(stream)
        return streams
    
    builder = torch_atb.Builder("Graph")
    builder.set_execute_streams(create_streams())