昇腾社区首页
中文
注册
开发者
下载

Build

函数功能

根据前期的设置,构建InferShapeContext,返回一个ContextHolder<InferShapeContext>对象。

函数原型

1
ContextHolder<InferShapeContext> Build()

参数说明

无。

返回值说明

返回一个ContextHolder<InferShapeContext>对象。通过调用GetContext()方法可获取InferShapeContext指针。

约束说明

  • 所有通过指针传入OpInferShapeContextBuilder的参数,其内存所有权归调用者所有;调用者必须确保这些指针在ContextHolder对象的整个生命周期内有效。
  • ContextHolder析构时会自动释放内部上下文资源。请勿手动释放GetContext()返回的指针。

调用示例

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
#include "base/context_builder/op_infer_shape_context_builder.h"
// Example: build an InferShapeContext for a node with four inputs and one output,
// then check what the resulting context exposes.
OpInferShapeContextBuilder ctx_builder;
// One StorageShape per input tensor.
// NOTE(review): presumably the first dim list is the origin shape and the second the
// storage shape — confirm against the StorageShape constructor.
StorageShape shape0 = {{1, 2, 3, 4}, {4, 3, 2, 1}};
StorageShape shape1 = {{2, 3, 4, 5}, {5, 4, 3, 2}};
StorageShape shape2 = {{3, 4, 5, 6}, {6, 5, 4, 3}};
StorageShape shape3 = {{4, 5, 6, 7}, {7, 6, 5, 4}};
// Format shared by all four inputs; the checks on input 0 further down expect
// origin format ND and storage format FRACTAL_NZ.
StorageFormat format{FORMAT_ND, FORMAT_FRACTAL_NZ, {}};
gert::Tensor tensor0{shape0, format, ge::DT_FLOAT};
gert::Tensor tensor1{shape1, format, ge::DT_FLOAT};
gert::Tensor tensor2{shape2, format, ge::DT_FLOAT};
gert::Tensor tensor3{shape3, format, ge::DT_FLOAT};
// The builder is handed raw pointers to the tensors above. Per the documented constraints
// of this API, the caller keeps ownership and the pointers must stay valid for the whole
// lifetime of `holder`.
std::vector<gert::Tensor *> input_tensors = {&tensor0, &tensor1, &tensor2, &tensor3};
// Configure the node and build the context. The checks below expect this setup to yield
// 4 IR inputs / 4 inputs and 1 IR output / 1 output, with output 0 described as
// DT_FLOAT, origin format ND, storage format NCHW.
auto holder = ctx_builder.OpType("DIY")
                  .OpName("diy_1")
                  .IOInstanceNum({1, 1, 1, 1}, {1})
                  .OutputTensorDesc(0, ge::DT_FLOAT, ge::FORMAT_ND, ge::FORMAT_NCHW)
                  .InputTensors(input_tensors)
                  .Build();
// The context is owned by `holder` and released when `holder` is destroyed;
// do not free `ctx` manually.
auto ctx = holder.GetContext();
EXPECT_NE(ctx, nullptr);
// Node identity and input/output counts recorded in the compute node info.
auto ctx_compute_node_info = ctx->GetComputeNodeInfo();
EXPECT_NE(ctx_compute_node_info, nullptr);
EXPECT_EQ(std::string(ctx_compute_node_info->GetNodeType()), std::string("DIY"));
EXPECT_EQ(std::string(ctx_compute_node_info->GetNodeName()), std::string("diy_1"));
EXPECT_EQ(ctx_compute_node_info->GetIrInputsNum(), 4);
EXPECT_EQ(ctx_compute_node_info->GetIrOutputsNum(), 1);
EXPECT_EQ(ctx_compute_node_info->GetInputsNum(), 4);
EXPECT_EQ(ctx_compute_node_info->GetOutputsNum(), 1);
// Compile-time tensor description of input 0 should carry the formats given in `format`.
const CompileTimeTensorDesc *info_input_0 = ctx_compute_node_info->GetInputTdInfo(0);
EXPECT_NE(info_input_0, nullptr);
EXPECT_EQ(info_input_0->GetStorageFormat(), ge::FORMAT_FRACTAL_NZ);
EXPECT_EQ(info_input_0->GetOriginFormat(), ge::FORMAT_ND);
// Each input shape reported by the context should equal the origin shape of the
// corresponding input tensor.
EXPECT_NE(ctx->GetInputShape(0), nullptr);
EXPECT_EQ(*(ctx->GetInputShape(0)), shape0.GetOriginShape());
EXPECT_NE(ctx->GetInputShape(1), nullptr);
EXPECT_EQ(*(ctx->GetInputShape(1)), shape1.GetOriginShape());
EXPECT_NE(ctx->GetInputShape(2), nullptr);
EXPECT_EQ(*(ctx->GetInputShape(2)), shape2.GetOriginShape());
EXPECT_NE(ctx->GetInputShape(3), nullptr);
EXPECT_EQ(*(ctx->GetInputShape(3)), shape3.GetOriginShape());
// Output 0 exists and is expected to report a shape with zero dimensions at this point.
// NOTE(review): presumably because no shape inference has run yet — confirm.
EXPECT_NE(ctx->GetOutputShape(0), nullptr);
EXPECT_EQ(ctx->GetOutputShape(0)->GetDimNum(), 0);
// Output 0 description should match what was passed to OutputTensorDesc above.
EXPECT_EQ(ctx->GetComputeNodeInfo()->GetOutputTdInfo(0)->GetDataType(), DT_FLOAT);
EXPECT_EQ(ctx->GetComputeNodeInfo()->GetOutputTdInfo(0)->GetOriginFormat(), FORMAT_ND);
EXPECT_EQ(ctx->GetComputeNodeInfo()->GetOutputTdInfo(0)->GetStorageFormat(), FORMAT_NCHW);