GetInputTensorRange
Function Usage
Obtains the pointer to the input tensor range based on the operator input index. The input index refers to the actual index after operator instantiation, not the index in the prototype definition.
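The difference between prototype indices and instantiated indices matters mainly for dynamic and optional inputs. The sketch below illustrates it under an assumption that this page does not state. The assumption is that each prototype input expands into zero or more consecutive instantiated inputs: one for a required input, zero for an optional input that is not connected, and N for a dynamic input with N connections. The struct `IrInputInstance` and the function `ToInstantiatedIndex` are hypothetical and are not part of the CANN API.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical description of one prototype (IR) input: how many actual
// inputs it produced when the operator was instantiated (assumption:
// required -> 1, unconnected optional -> 0, dynamic -> number of connections).
struct IrInputInstance {
  size_t instance_count;
};

// Hypothetical helper: computes the instantiated (actual) index of the
// `instance`-th instance of prototype input `ir_index`, i.e. the value that
// would be passed to GetInputTensorRange. Returns false if no such instance.
bool ToInstantiatedIndex(const std::vector<IrInputInstance> &ir_inputs,
                         size_t ir_index, size_t instance, size_t *out) {
  assert(out != nullptr);
  if (ir_index >= ir_inputs.size() ||
      instance >= ir_inputs[ir_index].instance_count) {
    return false;
  }
  size_t start = 0;
  // Every instance of an earlier prototype input occupies one actual index.
  for (size_t i = 0; i < ir_index; ++i) {
    start += ir_inputs[i].instance_count;
  }
  *out = start + instance;
  return true;
}
```

For example, with prototype inputs x (dynamic, 3 instances), bias (optional, not connected), and y (required), y has prototype index 2 but, under this assumption, instantiated index 3.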
Prototype
```cpp
const TensorRange *GetInputTensorRange(const size_t index) const
```
Parameters
| Parameter | Input/Output | Description |
|---|---|---|
| index | Input | Operator input index, starting from 0. |
Returns
A pointer to the TensorRange of the specified input. TensorRange is defined as follows:
```cpp
using TensorRange = Range<Tensor>;
```
If index is invalid, a null pointer is returned.
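The return contract (a range holding a lower-bound and an upper-bound tensor, or a null pointer for an invalid index) can be illustrated with simplified stand-ins. In the sketch below, `Range` only mimics the `GetMin`/`GetMax` accessors of a min/max range, and `FakeTensor` and `FakeContext` are hypothetical types; none of them is the real `gert` implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Simplified stand-in for a min/max range: holds pointers to the lower and
// upper bound objects. This sketch requires both bounds to be non-null.
template <typename T>
class Range {
 public:
  Range(T *min, T *max) : min_(min), max_(max) {
    assert(min_ != nullptr && max_ != nullptr);
  }
  T *GetMin() const { return min_; }
  T *GetMax() const { return max_; }

 private:
  T *min_;
  T *max_;
};

// Hypothetical tensor that only carries its dimensions.
struct FakeTensor {
  std::vector<long long> dims;
};
using TensorRange = Range<FakeTensor>;

// Hypothetical context mirroring the documented contract: a valid index
// yields a pointer to the range, an invalid index yields nullptr.
class FakeContext {
 public:
  explicit FakeContext(std::vector<TensorRange> ranges)
      : ranges_(std::move(ranges)) {}

  const TensorRange *GetInputTensorRange(const size_t index) const {
    if (index >= ranges_.size()) {
      return nullptr;
    }
    return &ranges_[index];
  }

 private:
  std::vector<TensorRange> ranges_;
};
```

Callers should therefore check the returned pointer before calling `GetMin()` or `GetMax()` on it.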
Constraints
If the input is not declared as data dependent, the tensors obtained through this API carry only valid shape, format, and data type information. The actual tensor data address cannot be obtained; the returned address is nullptr.
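Because of this constraint, inference code that reads tensor values should check the data address first. The sketch below models the rule with a hypothetical `FakeTensor` whose data pointer is set only for a data-dependent input; `FakeTensor` and `ReadInt64Values` are illustrative names, not CANN APIs, and the real `Tensor::GetData` is a template rather than the plain accessor used here.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical tensor stand-in: shape metadata is always valid, but the data
// pointer is set only when the input is declared data dependent.
struct FakeTensor {
  std::vector<int64_t> dims;
  const int64_t *data;  // nullptr when the input is not data dependent

  int64_t GetShapeSize() const {
    int64_t n = 1;
    for (int64_t d : dims) {
      n *= d;
    }
    return n;
  }
  const int64_t *GetData() const { return data; }
};

// Hypothetical helper: copies the int64 values held by a tensor (for example,
// a shape tensor) only if its data address is available. Returns false
// otherwise, so the caller can fall back to shape-only inference.
bool ReadInt64Values(const FakeTensor &tensor, std::vector<int64_t> *values) {
  assert(values != nullptr);
  const int64_t *data = tensor.GetData();
  if (data == nullptr) {
    return false;  // only shape, format, and data type are usable here
  }
  values->assign(data, data + tensor.GetShapeSize());
  return true;
}
```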
Examples
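```cpp
const auto infer_shape_range_func = [](gert::InferShapeRangeContext *context) -> graphStatus {
  // Range of input tensor 0 and shape range of output 0.
  auto input_tensor_range = context->GetInputTensorRange(0U);
  auto output_shape_range = context->GetOutputShapeRange(0U);
  if (input_tensor_range == nullptr || output_shape_range == nullptr) {
    return GRAPH_FAILED;  // invalid index
  }
  // Upper-bound tensor of the range.
  auto input_tensor = input_tensor_range->GetMax();
  // The data address is valid only if input 0 is declared as data dependent.
  auto shape_data = input_tensor->GetData<int64_t>();
  auto shape_size = input_tensor->GetShapeSize();
  // ... The inference logic of output_shape_range should be added here.
  return GRAPH_SUCCESS;
};
```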