Defining a Constant Node (Const)
The Const operator represents constant tensors, such as weights and biases.
The Const operator prototype is defined as follows.
```cpp
REG_OP(Const)
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_BF16, DT_INT4, DT_INT8, DT_INT16, DT_UINT16, \
        DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
    .ATTR(value, Tensor, Tensor())
    .OP_END_FACTORY_REG(Const)
```
Constructing Weight Data
The following example constructs weight data in memory and creates a Const operator instance, based on the prototype definition, whose value attribute is set to weighttensor1.
```cpp
// Construct weighttensor1: a 1x3x3x3 INT8 tensor with every element set to 1.
TensorDesc weight_desc(ge::Shape({1, 3, 3, 3}), FORMAT_NCHW, DT_INT8);
int bs_size_weight = 27;  // 1 * 3 * 3 * 3 elements
int8_t *bs_inputData_weight = new int8_t[bs_size_weight];
for (int i = 0; i < bs_size_weight; ++i) {
  bs_inputData_weight[i] = 1;
}
Tensor weighttensor1(weight_desc, reinterpret_cast<uint8_t *>(bs_inputData_weight),
                     bs_size_weight * sizeof(int8_t));
// The Tensor copies the buffer (as SetData does in GetConstTensorFromBin below),
// so the temporary buffer can be released here.
delete[] bs_inputData_weight;
// Create a Const operator, and set the attribute value to weighttensor1.
auto weight1 = op::Const().set_attr_value(weighttensor1);
```
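The byte count passed to the Tensor constructor is the element count of the shape (1 × 3 × 3 × 3 = 27) multiplied by the element size (1 byte for DT_INT8). A minimal standard-C++ sketch of that arithmetic follows; ShapeByteSize and MakeFilledBuffer are hypothetical helpers, not GE APIs, and only cover fixed-size element types such as INT8.

```cpp
#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Hypothetical helper (not a GE API): bytes needed for a dense tensor with the
// given dims and element size, e.g. {1, 3, 3, 3} with 1-byte INT8 elements -> 27.
std::size_t ShapeByteSize(const std::vector<int64_t> &dims, std::size_t elem_size) {
  int64_t count = std::accumulate(dims.begin(), dims.end(), int64_t{1},
                                  std::multiplies<int64_t>());
  return static_cast<std::size_t>(count) * elem_size;
}

// Hypothetical helper: an INT8 buffer of the right size with every element set to value.
std::vector<int8_t> MakeFilledBuffer(const std::vector<int64_t> &dims, int8_t value) {
  return std::vector<int8_t>(ShapeByteSize(dims, sizeof(int8_t)), value);
}
```

With a std::vector, the buffer is released automatically; its data() pointer and size() would be passed to the Tensor constructor in the same way as bs_inputData_weight and its byte size in the example above.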
If an input and an output in an operator's prototype have the same name, the operator is an in-place operator: its output overwrites that input. In this case, the input cannot be connected to a Const node.
Reading Weight Data from a File
Weight data can also be read directly from a binary file.
```cpp
// Construct weight_tensor.
auto weight_shape = ge::Shape({5, 17, 1, 1});
TensorDesc desc_weight_1(weight_shape, FORMAT_NCHW, DT_INT8);
Tensor weight_tensor(desc_weight_1);
// GetShapeSize() returns the number of elements (5 * 17 * 1 * 1 = 85).
int64_t weight_1_len = weight_shape.GetShapeSize();
// "const_0.bin" is the path of the constant file. Check the result: on failure
// weight_tensor holds no data.
if (!GetConstTensorFromBin("const_0.bin", weight_tensor,
                           static_cast<uint32_t>(weight_1_len * sizeof(int8_t)))) {
  std::cout << "Failed to load const_0.bin\n";
  // Handle the error as appropriate for the calling code.
}
// Create a Const operator, and set the attribute value to weight_tensor.
auto conv_weight = op::Const("const_0").set_attr_value(weight_tensor);
```
A sample implementation of the GetConstTensorFromBin helper function is as follows.
```cpp
#include <fstream>
#include <iostream>
#include <memory>
#include <new>
#include <string>
#include "graph/tensor.h"

bool GetConstTensorFromBin(const std::string &path, ge::Tensor &weight, uint32_t len) {
  // Open the file in binary mode.
  std::ifstream in_file(path, std::ios::in | std::ios::binary);
  if (!in_file.is_open()) {
    std::cout << "Failed to open " << path << "\n";
    return false;
  }
  // Move the file pointer to the end of the file to obtain the total file size,
  // then move it back to the beginning of the file.
  in_file.seekg(0, std::ios_base::end);
  std::streamoff file_size = in_file.tellg();
  in_file.seekg(0, std::ios_base::beg);
  // Check whether the specified len is the same as the actual file size.
  if (file_size < 0 || static_cast<std::streamoff>(len) != file_size) {
    std::cout << "Invalid param. len: " << len << " is not equal to the binary size ("
              << file_size << ")\n";
    return false;
  }
  // Allocate a buffer for the file content; unique_ptr frees it on every return path.
  std::unique_ptr<char[]> pdata(new (std::nothrow) char[len]);
  if (pdata == nullptr) {
    std::cout << "Failed to allocate " << len << " bytes\n";
    return false;
  }
  // Read data from the file into the buffer.
  if (!in_file.read(pdata.get(), len)) {
    std::cout << "Failed to read " << path << "\n";
    return false;
  }
  // Copy the data in the buffer into the Tensor object weight.
  auto status = weight.SetData(reinterpret_cast<const uint8_t *>(pdata.get()), len);
  if (status != ge::GRAPH_SUCCESS) {
    std::cout << "Set Tensor Data Failed\n";
    return false;
  }
  return true;  // in_file is closed automatically when it goes out of scope
}
```
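The open, size-check, and read steps above do not depend on GE, so they can be exercised on their own. The sketch below is a standard-C++ variant under that assumption; ReadExactBinary is a hypothetical helper, not part of the GE API, and it returns the bytes instead of writing them into a Tensor.

```cpp
#include <cstddef>
#include <cstdint>
#include <fstream>
#include <optional>
#include <string>
#include <vector>

// Hypothetical helper (not a GE API): read a raw binary file whose size must be
// exactly expected_len bytes. Returns std::nullopt on open, size, or read failure.
std::optional<std::vector<uint8_t>> ReadExactBinary(const std::string &path,
                                                    std::size_t expected_len) {
  std::ifstream in(path, std::ios::in | std::ios::binary);
  if (!in.is_open()) {
    return std::nullopt;
  }
  // Determine the file size, then rewind to the beginning.
  in.seekg(0, std::ios_base::end);
  std::streamoff file_size = in.tellg();
  in.seekg(0, std::ios_base::beg);
  if (file_size < 0 || static_cast<std::size_t>(file_size) != expected_len) {
    return std::nullopt;  // same rule as the len check in GetConstTensorFromBin
  }
  std::vector<uint8_t> data(expected_len);
  if (expected_len > 0 &&
      !in.read(reinterpret_cast<char *>(data.data()),
               static_cast<std::streamsize>(expected_len))) {
    return std::nullopt;
  }
  return data;
}
```

On success, the returned vector's data() and size() could be passed to the same SetData call used in GetConstTensorFromBin.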
The parameters of the GetConstTensorFromBin function are described as follows:
- path: (input) the path of the .bin weight file to read (for example, a file under ../data/weight/). The original weight file must first be converted manually into a raw binary file.
- weight: (output) the tensor that receives the weight data read from the file.
- len: (input) the size of the weight data in bytes. It must equal the size of the binary file; otherwise the function returns false.
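GetConstTensorFromBin expects a raw file that contains exactly len bytes and no header. Assuming the weights have already been extracted into memory as INT8 values in NCHW order (parsing the original weight format is not covered here), such a file could be written as in the following sketch; WriteConstBin is a hypothetical helper.

```cpp
#include <cstdint>
#include <fstream>
#include <string>
#include <vector>

// Hypothetical helper (not a GE API): write values as raw bytes with no header,
// so that the file size equals values.size() * sizeof(int8_t).
bool WriteConstBin(const std::string &path, const std::vector<int8_t> &values) {
  std::ofstream out(path, std::ios::out | std::ios::binary | std::ios::trunc);
  if (!out.is_open()) {
    return false;
  }
  out.write(reinterpret_cast<const char *>(values.data()),
            static_cast<std::streamsize>(values.size() * sizeof(int8_t)));
  return static_cast<bool>(out);  // false if the write failed
}
```

For the 5 × 17 × 1 × 1 INT8 weight in the earlier example, the resulting file is 85 bytes, which matches weight_1_len * sizeof(int8_t).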