简介
图构建器类,供C++用户构建计算图。该类提供创建输入等图元素的方法,并支持设置私有属性、构建最终的图;此外还提供转换逻辑,供C++的ES generated API的内联函数在内部调用C的ES generated API时使用。
需要包含的头文件
```cpp
#include "es_graph_builder.h"
```
Public成员函数
```cpp
explicit EsGraphBuilder(const char *name);
template <typename T>
std::unique_ptr<Tensor> CreateTensor(const std::vector<T> &value, const std::vector<int64_t> &dims, DataType dt, Format format = FORMAT_ND);
std::unique_ptr<Tensor> CreateBoolTensor(const std::vector<bool> &value, const std::vector<int64_t> &dims, Format format = FORMAT_ND);
EsTensorHolder CreateInput(int64_t index, const char *name, const char *type);
template <size_t inputs_num>
std::array<EsTensorHolder, inputs_num> CreateInputs(size_t start_index = 0);
std::vector<EsTensorHolder> CreateInputs(size_t inputs_num, size_t start_index = 0);
EsTensorHolder CreateInput(int64_t index);
EsTensorHolder CreateInput(int64_t index, const char *name);
EsTensorHolder CreateInput(int64_t index, const char *name, ge::DataType data_type, ge::Format format, const std::vector<int64_t> &shape);
EsTensorHolder CreateConst(const std::vector<int64_t> &value, const std::vector<int64_t> dims);
EsTensorHolder CreateConst(const std::vector<int32_t> &value, const std::vector<int64_t> dims);
EsTensorHolder CreateConst(const std::vector<uint64_t> &value, const std::vector<int64_t> dims);
EsTensorHolder CreateConst(const std::vector<uint32_t> &value, const std::vector<int64_t> dims);
EsTensorHolder CreateConst(const std::vector<float> &value, const std::vector<int64_t> dims);
template <typename T>
EsTensorHolder CreateConst(const std::vector<T> &value, const std::vector<int64_t> &dims, ge::DataType dt, ge::Format format = FORMAT_ND);
EsTensorHolder CreateVector(const std::vector<int64_t> &value);
EsTensorHolder CreateVector(const std::vector<int32_t> &value);
EsTensorHolder CreateVector(const std::vector<uint64_t> &value);
EsTensorHolder CreateVector(const std::vector<uint32_t> &value);
EsTensorHolder CreateVector(const std::vector<float> &value);
EsTensorHolder CreateScalar(int64_t value);
EsTensorHolder CreateScalar(int32_t value);
EsTensorHolder CreateScalar(uint64_t value);
EsTensorHolder CreateScalar(uint32_t value);
EsTensorHolder CreateScalar(float value);
EsTensorHolder CreateVariable(int32_t index, const char *name);
Status SetAttr(const char *attr_name, int64_t value);
Status SetAttr(const char *attr_name, const char *value);
Status SetAttr(const char *attr_name, bool value);
static int64_t SetOutput(const EsTensorHolder &tensor, int64_t index);
std::unique_ptr<Graph> BuildAndReset() const;
std::unique_ptr<Graph> BuildAndReset(const std::vector<EsTensorHolder> &outputs);
EsCGraphBuilder *GetCGraphBuilder() const;
```
对外函数
```cpp
template <typename T>
std::unique_ptr<Tensor> CreateTensor(const T *value, const int64_t *dims, int64_t dim_num, DataType dt, Format format = FORMAT_ND);
template <typename T>
std::unique_ptr<Tensor> CreateTensorFromFile(const char *data_file_path, const std::vector<int64_t> &dims, DataType dt, Format format = FORMAT_ND);
template <typename T>
EsCTensorHolder *EsCreateConst(EsCGraphBuilder *graph, const T *value, const int64_t *dims, int64_t dim_num, ge::DataType dt, ge::Format format = FORMAT_ND);
inline std::vector<EsCTensorHolder *> TensorsToEsCTensorHolders(const std::vector<EsTensorHolder> &tensors);
inline std::vector<C_DataType> DataTypesToEsCDataTypes(const std::vector<ge::DataType> &data_types);
template <typename T>
std::pair<std::vector<const T *>, std::vector<int64_t>> ListListTypeToPtrAndCounts(const std::vector<std::vector<T>> &list_list_type);
inline std::vector<EsCGraph *> GeGraphsToEsCGraphs(std::vector<std::unique_ptr<ge::Graph>> graphs);
```
父主题: EsGraphBuilder