GraphBuilder
创建GraphBuilder
- 功能
创建GraphBuilder,用于后续的构图操作。
- 原型
torch_atb.GraphBuilder (graphName: string) -> torch_atb.GraphBuilder
- 参数
图op的名称,类型为String。
- 返回值
GraphBuilder
- 注意
调用接口失败,会抛出异常终止运行。
- 使用示例
1
builder = torch_atb.Builder("Graph")
设置计算图的输入
- 功能
设置计算图的输入。
- 原型
GraphBuilder.AddInput (name: string) -> torch_atb.GraphBuilder
- 参数
输入的名称,类型为String。
- 返回值
string
- 注意
调用接口失败,会抛出异常终止运行。
- 使用示例
1 2
x = builder.add_input("x") y = builder.add_input("y")
设置计算图的节点
- 功能
设置计算图的节点。
- 原型
GraphBuilder.AddNode (inputs: vector<string>, param: xxxParam) -> torch_atb.GraphBuilder
- 参数
输入的名称,使用torch_atb创建的算子的参数。
- 返回值
GraphNode
- 注意
调用接口失败,会抛出异常终止运行。
- 使用示例
1 2 3 4
# 创建并配置算子参数 elewise_add = torch_atb.ElewiseParam() elewise_add.elewise_type = torch_atb.ElewiseParam.ElewiseType.ELEWISE_ADD layer1 = builder.add_node([x, y], elewise_add)
设置计算图的输出
- 功能
标记传入tensor为计算图输出。
- 原型
GraphBuilder.MarkOutput (outTensor: string) -> torch_atb.GraphBuilder
- 参数
tensor的名称。
- 注意
调用接口失败,会抛出异常终止运行。
- 使用示例
1
builder.mark_output(layer2.get_output(0))
张量转置
- 功能
改变tensor的shape。
- 原型
GraphBuilder.Reshape (string, std::function, string) -> torch_atb.GraphBuilder
- 参数
tensor的名称。
- 注意
调用接口失败,会抛出异常终止运行。
- 使用示例
1 2 3 4 5
layer1 = builder.add_node([x, y], elewise_add) add_out = layer1.get_output(0) builder.reshape(add_out, lambda shape: [1, shape[0] * shape[1]], "add_out_") # 转置后使用"add_out_"将tensor与其他node绑定 layer2 = builder.add_node(["add_out_", z], elewise_mul)
构建计算图
- 功能
创建图op。
- 原型
GraphBuilder.Build () -> torch_atb.GraphBuilder
- 返回值
atb::OperationWrapper
- 注意
调用接口失败,会抛出异常终止运行。
- 使用示例
1 2 3
builder = torch_atb.Builder("Graph") # 省略添加输入,节点,标记输出过程 Graph = builder.build()
设置实际执行流
- 功能
设置实际执行流。
- 原型
GraphBuilder.SetExecuteStreams (executeStreams: std::vector<std::uintptr_t>) -> torch_atb.GraphBuilder
- 参数
实际执行流(aclStream)列表
- 注意
调用接口失败,会抛出异常终止运行。
- 使用示例
1 2 3 4 5 6 7 8 9 10 11
def create_streams(): streams = [] for i in range(2): stream, ret = acl.rt.create_stream() if ret != ACL_SUCCESS: exit(0) streams.append(stream) return streams builder = torch_atb.Builder("Graph") builder.set_execute_streams(create_streams())
父主题: 组图接口