昇腾社区首页
中文
注册
开发者
下载

Build

函数功能

根据之前的设置,构建TilingContext,返回一个ContextHolder<TilingContext>对象。

函数原型

1
ContextHolder<TilingContext> Build()

参数说明

无。

返回值说明

返回一个ContextHolder<TilingContext>对象。通过调用GetContext()方法可获取TilingContext指针。

约束说明

  • 所有通过指针传入的参数,其内存所有权归调用者所有;调用者必须确保这些指针在ContextHolder对象的整个生命周期内有效。
  • ContextHolder析构时会自动释放内部上下文资源。请勿手动释放GetContext()返回的指针。

调用示例

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#include "base/context_builder/op_tiling_context_builder.h"
// Usage example: configure an OpTilingContextBuilder, call Build(), and verify that the
// resulting TilingContext exposes everything that was handed to the builder.

// Workspace-size vector created by the caller and passed to the builder by raw pointer.
// The holder must stay alive for as long as `holder` (the built context) is used.
auto workspace_size_holder = gert::ContinuousVector::Create<size_t>(4096);
auto ws_ptr = reinterpret_cast<gert::ContinuousVector *>(workspace_size_holder.get());

// Expected shapes used by the assertions further down.
gert::Shape shape_0{1, 1, 1, 1, 1};
gert::Shape shape_1{10, 10, 10, 10, 20};
gert::Shape shape_2{1, 1, 1, 1, 1};
gert::Shape shape_3{10, 10, 10, 10, 20};
gert::Shape resultShape{10, 10, 10, 10, 20};

// Storage shapes (origin shape, storage shape) for the four inputs and the single output.
gert::StorageShape x({1, 1, 1, 1, 1}, {1, 1, 1, 1, 1});
gert::StorageShape resultIn({10, 10, 10, 10, 20}, {10, 10, 10, 10, 20});
gert::StorageShape gamma({1, 1, 1, 1, 1}, {1, 1, 1, 1, 1});
gert::StorageShape beta({10, 10, 10, 10, 20}, {10, 10, 10, 10, 20});
gert::StorageShape result({10, 10, 10, 10, 20}, {10, 10, 10, 10, 20});

// Only the first input tensor carries a data address; the other tensors are built with nullptr.
uint8_t data_x[1] = {9};
gert::Tensor x_tensor(x, {ge::FORMAT_NCDHW, ge::FORMAT_RESERVED, ExpandDimsType()}, TensorPlacement::kOnHost,
                      ge::DT_FLOAT, (void *) data_x);
gert::Tensor resultIn_tensor(resultIn, {ge::FORMAT_NCDHW, ge::FORMAT_RESERVED, ExpandDimsType()},
                             TensorPlacement::kOnHost, ge::DT_FLOAT, nullptr);
gert::Tensor gammax_tensor(gamma, {ge::FORMAT_NCDHW, ge::FORMAT_RESERVED, ExpandDimsType()}, TensorPlacement::kOnHost,
                           ge::DT_FLOAT, nullptr);
gert::Tensor beta_tensor(beta, {ge::FORMAT_NCDHW, ge::FORMAT_RESERVED, ExpandDimsType()}, TensorPlacement::kOnHost,
                         ge::DT_FLOAT, nullptr);
gert::Tensor result_tensor(result, {ge::FORMAT_NCDHW, ge::FORMAT_RESERVED, ExpandDimsType()},
                           TensorPlacement::kOnHost, ge::DT_FLOAT, nullptr);

uint8_t tmp_compile_info[] = XX; // XX is a placeholder for fake data
uint8_t tmp_platform_info[] = XX;// XX is a placeholder for fake data
int32_t deterministic = 1;

// Attributes are read back by index in the order they are appended here (0..8).
OpTilingContextBuilder ctx_builder;
auto holder = ctx_builder.OpName("tmp")
                  .OpType("DIY")
                  .IONum(4, 1)
                  .AppendAttr(int64_t(1))
                  .AppendAttr(bool(true))
                  .AppendAttr(float(0.3))
                  .AppendAttr(AscendString("my_info"))
                  .AppendAttr(std::vector<bool>({true, false, true}))
                  .AppendAttr(std::vector<int64_t>({1, 2, 3}))
                  .AppendAttr(std::vector<float>({0.1, 0.2, 0.3}))
                  .AppendAttr(std::vector<AscendString>({"123", "234"}))
                  .AppendAttr(std::vector<std::vector<int64_t>>({{1, 2, 3}, {4, 5, 6}}))
                  .TilingDataSize(100)
                  .Workspace(ws_ptr)
                  .CompileInfo(tmp_compile_info)
                  .Deterministic(deterministic)
                  .PlatformInfo(tmp_platform_info)
                  .InputTensors({&x_tensor, &resultIn_tensor, &gammax_tensor, &beta_tensor})
                  .OutputTensors({&result_tensor})
                  .Build();

// The context pointer is owned by `holder`; do not release it manually.
auto ctx = holder.GetContext();
EXPECT_NE(ctx, nullptr);
auto ctx_compute_node_info = ctx->GetComputeNodeInfo();
EXPECT_NE(ctx_compute_node_info, nullptr);
EXPECT_EQ(ctx->GetCompileInfo(), tmp_compile_info);

// Input 0: shapes, data address and tensor identity match what was passed in.
EXPECT_EQ(ctx->GetInputShape(0)->GetOriginShape(), shape_0);
EXPECT_EQ(ctx->GetInputShape(0)->GetStorageShape(), shape_0);
EXPECT_EQ(ctx->GetInputTensor(0)->GetAddr(), data_x);
EXPECT_EQ(ctx->GetInputTensor(0), &x_tensor);
EXPECT_EQ(ctx->GetInputTensor(0)->GetStorageShape(), x_tensor.GetStorageShape());
EXPECT_EQ(ctx->GetInputTensor(0)->GetOriginShape(), x_tensor.GetOriginShape());
EXPECT_EQ(ctx->GetInputTensor(0)->GetSize(), x_tensor.GetSize());

// Output 0.
EXPECT_EQ(ctx->GetOutputShape(0)->GetOriginShape(), resultShape);
EXPECT_EQ(ctx->GetOutputShape(0)->GetStorageShape(), resultShape);

// Workspace, platform info and deterministic flag are the caller-provided values.
EXPECT_EQ(static_cast<void *>(ctx->GetWorkspaceSizes(4096)), static_cast<const void *>(ws_ptr->GetData()));
EXPECT_EQ(static_cast<void *>(ctx->GetPlatformInfo()), static_cast<void *>(tmp_platform_info));
EXPECT_EQ(ctx->GetDeterministic(), deterministic);

// The builder was given only a tiling-data size (TilingDataSize(100)), not a caller-owned
// buffer, so there is no caller-side pointer to compare against; check that a tiling-data
// object is available instead.
// NOTE(review): presumably the buffer is allocated by the builder and released with `holder` - confirm.
EXPECT_NE(ctx->GetRawTilingData(), nullptr);

// Scalar attributes (indices 0-3).
EXPECT_EQ(*(ctx->GetAttrs()->GetInt(0)), 1);
EXPECT_EQ(*(ctx->GetAttrs()->GetBool(1)), true);
EXPECT_FLOAT_EQ(*(ctx->GetAttrs()->GetFloat(2)), 0.3);
auto str_ptr = ctx->GetAttrs()->GetStr(3);
EXPECT_EQ(strcmp(str_ptr, "my_info"), 0);

// List attributes (indices 4-6). Index 7 (the list of strings) is not checked in this example.
auto bool_vec = ctx->GetAttrs()->GetAttrPointer<TypedContinuousVector<bool>>(4);
EXPECT_EQ(bool_vec->GetData()[0], true);
EXPECT_EQ(bool_vec->GetData()[1], false);
EXPECT_EQ(bool_vec->GetData()[2], true);
EXPECT_EQ(ctx->GetAttrs()->GetListInt(5)->GetData()[0], 1);
EXPECT_EQ(ctx->GetAttrs()->GetListInt(5)->GetData()[1], 2);
EXPECT_EQ(ctx->GetAttrs()->GetListInt(5)->GetData()[2], 3);
EXPECT_FLOAT_EQ(ctx->GetAttrs()->GetListFloat(6)->GetData()[0], 0.1);
EXPECT_FLOAT_EQ(ctx->GetAttrs()->GetListFloat(6)->GetData()[1], 0.2);
EXPECT_FLOAT_EQ(ctx->GetAttrs()->GetListFloat(6)->GetData()[2], 0.3);

// Nested list attribute (index 8): two rows of three int64_t values each.
auto int_vec_vec = ctx->GetAttrs()->GetListListInt(8);
EXPECT_EQ(((int64_t *) (int_vec_vec->Get(0)->GetData()))[0], 1);
EXPECT_EQ(((int64_t *) (int_vec_vec->Get(0)->GetData()))[1], 2);
EXPECT_EQ(((int64_t *) (int_vec_vec->Get(0)->GetData()))[2], 3);
EXPECT_EQ(((int64_t *) (int_vec_vec->Get(1)->GetData()))[0], 4);
EXPECT_EQ(((int64_t *) (int_vec_vec->Get(1)->GetData()))[1], 5);
EXPECT_EQ(((int64_t *) (int_vec_vec->Get(1)->GetData()))[2], 6);