GraphBuilder
创建GraphBuilder
- 功能
创建GraphBuilder,用于后续的构图操作。
 - 原型
torch_atb.GraphBuilder (graphName: string) -> torch_atb.GraphBuilder
 - 参数
图op的名称,类型为String。
 - 返回值
GraphBuilder
 - 注意
调用接口失败,会抛出异常终止运行。
 - 使用示例
1builder = torch_atb.Builder("Graph")
 
设置计算图的输入
- 功能
设置计算图的输入。
 - 原型
GraphBuilder.AddInput (name: string) -> torch_atb.GraphBuilder
 - 参数
输入的名称,类型为String。
 - 返回值
string
 - 注意
调用接口失败,会抛出异常终止运行。
 - 使用示例
1 2
x = builder.add_input("x") y = builder.add_input("y")
 
设置计算图的节点
- 功能
设置计算图的节点。
 - 原型
GraphBuilder.AddNode (inputs: vector<string>, param: xxxParam) -> torch_atb.GraphBuilder
 - 参数
输入的名称,使用torch_atb创建的算子的参数。
 - 返回值
GraphNode
 - 注意
调用接口失败,会抛出异常终止运行。
 - 使用示例
1 2 3 4
# 创建并配置算子参数 elewise_add = torch_atb.ElewiseParam() elewise_add.elewise_type = torch_atb.ElewiseParam.ElewiseType.ELEWISE_ADD layer1 = builder.add_node([x, y], elewise_add)
 
设置计算图的输出
- 功能
标记传入tensor为计算图输出。
 - 原型
GraphBuilder.MarkOutput (outTensor: string) -> torch_atb.GraphBuilder
 - 参数
tensor的名称。
 - 注意
调用接口失败,会抛出异常终止运行。
 - 使用示例
1builder.mark_output(layer2.get_output(0))
 
张量转置
- 功能
改变tensor的shape。
 - 原型
GraphBuilder.Reshape (string, std::function, string) -> torch_atb.GraphBuilder
 - 参数
tensor的名称。
 - 注意
调用接口失败,会抛出异常终止运行。
 - 使用示例
1 2 3 4 5
layer1 = builder.add_node([x, y], elewise_add) add_out = layer1.get_output(0) builder.reshape(add_out, lambda shape: [1, shape[0] * shape[1]], "add_out_") # 转置后使用"add_out_"将tensor与其他node绑定 layer2 = builder.add_node(["add_out_", z], elewise_mul)
 
构建计算图
- 功能
创建图op。
 - 原型
GraphBuilder.Build () -> torch_atb.GraphBuilder
 - 返回值
atb::OperationWrapper - 注意
调用接口失败,会抛出异常终止运行。
 - 使用示例
1 2 3
builder = torch_atb.Builder("Graph") # 省略添加输入,节点,标记输出过程 Graph = builder.build()
 
设置实际执行流
- 功能
设置实际执行流。
 - 原型
GraphBuilder.SetExecuteStreams (executeStreams: std::vector<std::uintptr_t>) -> torch_atb.GraphBuilder
 - 参数
实际执行流(aclStream)列表
 - 注意
调用接口失败,会抛出异常终止运行。
 - 使用示例
1 2 3 4 5 6 7 8 9 10 11
def create_streams(): streams = [] for i in range(2): stream, ret = acl.rt.create_stream() if ret != ACL_SUCCESS: exit(0) streams.append(stream) return streams builder = torch_atb.Builder("Graph") builder.set_execute_streams(create_streams())
 
父主题: 组图接口