昇腾社区首页
中文
注册
开发者
下载

Build

函数功能

根据之前的设置,构建InferDataTypeContext,返回一个ContextHolder<InferDataTypeContext>对象。

函数原型

1
ContextHolder<InferDataTypeContext> Build()

参数说明

返回值说明

返回一个ContextHolder<InferDataTypeContext>对象。通过调用GetContext()方法可获取InferDataTypeContext指针。

约束说明

  • 所有通过指针传入OpInferDataTypeContextBuilder的参数,其内存所有权归调用者。调用者必须确保这些指针在ContextHolder对象的整个生命周期内有效。
  • ContextHolder析构时会自动释放内部上下文资源。请勿手动释放GetContext()返回的指针。

调用示例

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
#include "base/context_builder/op_infer_datatype_context_builder.h"
OpInferDataTypeContextBuilder ctx_builder;
ge::DataType dtype0 = ge::DT_FLOAT;
ge::DataType dtype1 = ge::DT_FLOAT16;
ge::DataType dtype2 = ge::DT_FLOAT;
ge::DataType dtype3 = ge::DT_FLOAT16;
ge::DataType dtype4 = ge::DT_FLOAT16;
std::vector<ge::DataType *> input_dtype_ref = {&dtype0, &dtype1, &dtype2, &dtype3};
std::vector<ge::DataType *> output_dtype_ref = {&dtype4};
auto holder = ctx_builder.OpType("Concat")
                  .OpName("concat_1")
                  .IOInstanceNum({4}, {1})
                  .InputTensorDesc(0, dtype0, ge::FORMAT_ND, ge::FORMAT_ND)
                  .InputTensorDesc(1, dtype1, ge::FORMAT_ND, ge::FORMAT_ND)
                  .InputTensorDesc(2, dtype2, ge::FORMAT_ND, ge::FORMAT_ND)
                  .InputTensorDesc(3, dtype3, ge::FORMAT_ND, ge::FORMAT_ND)
                  .OutputTensorDesc(0, ge::FORMAT_ND, ge::FORMAT_ND)
                  .Build();
auto ctx = holder.GetContext();
EXPECT_NE(ctx, nullptr);
auto ctx_compute_node_info = ctx->GetComputeNodeInfo();
EXPECT_NE(ctx_compute_node_info, nullptr);
EXPECT_EQ(std::string(ctx_compute_node_info->GetNodeType()), std::string("Concat"));
EXPECT_EQ(std::string(ctx_compute_node_info->GetNodeName()), std::string("concat_1"));
EXPECT_EQ(ctx_compute_node_info->GetIrInputsNum(), 1);
EXPECT_EQ(ctx_compute_node_info->GetIrOutputsNum(), 1);
EXPECT_EQ(ctx_compute_node_info->GetInputsNum(), 4);
EXPECT_EQ(ctx_compute_node_info->GetOutputsNum(), 1);
const CompileTimeTensorDesc *info_input_0 = ctx_compute_node_info->GetInputTdInfo(0);
EXPECT_NE(info_input_0, nullptr);
EXPECT_EQ(info_input_0->GetStorageFormat(), ge::FORMAT_ND);
EXPECT_EQ(info_input_0->GetOriginFormat(), ge::FORMAT_ND);
ge::DataType expected_datatype_0 = ge::DT_FLOAT;
ge::DataType expected_datatype_1 = ge::DT_FLOAT16;
EXPECT_EQ(ctx->GetInputDataType(0), expected_datatype_0);
EXPECT_EQ(ctx->GetInputDataType(1), expected_datatype_1);
EXPECT_EQ(ctx->GetInputDataType(2), expected_datatype_0);
EXPECT_EQ(ctx->GetInputDataType(3), expected_datatype_1);
EXPECT_EQ(ctx->GetOutputDataType(0), ge::DT_MAX);