GetRequiredInputShapeRange
Description
Obtains a pointer to the shape range of a required input, based on the input's index in the operator IR prototype definition.
Prototype
const Range<Shape> *GetRequiredInputShapeRange(const size_t ir_index) const
Parameters
| Parameter | Input/Output | Description |
|---|---|---|
| ir_index | Input | Index of the required input in the operator IR prototype definition, starting from 0. |
Returns
Pointer to the shape range. If ir_index is invalid, a null pointer is returned.
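The lookup-and-null-return contract described above can be illustrated with a small self-contained sketch. It does not use CANN. MockShape, MockShapeRange, MockInferShapeRangeContext and PropagateFirstInputRange are hypothetical stand-ins introduced only for illustration. Treating "invalid" as "index past the last registered required input" is also an assumption, and the real context may validate ir_index differently.

```cpp
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical stand-in for gert::Shape (illustration only, not the CANN type).
struct MockShape {
  std::vector<int64_t> dims;
};

// Hypothetical stand-in for gert::Range<gert::Shape>: a [min, max] pair of shapes.
struct MockShapeRange {
  MockShape min;
  MockShape max;
};

// Hypothetical context mimicking the documented contract of
// GetRequiredInputShapeRange: lookup by IR index, nullptr for an invalid index.
class MockInferShapeRangeContext {
 public:
  MockInferShapeRangeContext(std::vector<MockShapeRange> inputs, size_t output_num)
      : inputs_(std::move(inputs)), outputs_(output_num) {}

  const MockShapeRange *GetRequiredInputShapeRange(const size_t ir_index) const {
    if (ir_index >= inputs_.size()) {
      return nullptr;  // assumed meaning of "invalid": index out of range
    }
    return &inputs_[ir_index];
  }

  MockShapeRange *GetOutputShapeRange(const size_t index) {
    if (index >= outputs_.size()) {
      return nullptr;
    }
    return &outputs_[index];
  }

 private:
  std::vector<MockShapeRange> inputs_;
  std::vector<MockShapeRange> outputs_;
};

// Propagates the range of required input 0 to output 0, returning false
// instead of dereferencing a null pointer when either lookup fails.
bool PropagateFirstInputRange(MockInferShapeRangeContext &context) {
  const MockShapeRange *input_range = context.GetRequiredInputShapeRange(0U);
  MockShapeRange *output_range = context.GetOutputShapeRange(0U);
  if (input_range == nullptr || output_range == nullptr) {
    return false;
  }
  output_range->min = input_range->min;
  output_range->max = input_range->max;
  return true;
}
```

Under these assumptions, a context with one required input returns a valid pointer for index 0 and nullptr for any larger index. A context with no required inputs therefore makes the propagation helper fail cleanly rather than crash.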
Restrictions
None
Example
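const auto infer_shape_range_func = [](gert::InferShapeRangeContext *context) -> graphStatus {
  // Shape range of the required input whose IR index is 0
  auto input_shape_range = context->GetRequiredInputShapeRange(0U);
  // Shape range of output 0, to be filled in below
  auto output_shape_range = context->GetOutputShapeRange(0U);
  // GetRequiredInputShapeRange returns a null pointer for an invalid ir_index
  if (input_shape_range == nullptr || output_shape_range == nullptr) {
    return GRAPH_FAILED;
  }
  // Output 0 takes the same [min, max] shape range as input 0
  output_shape_range->SetMin(const_cast<gert::Shape *>(input_shape_range->GetMin()));
  output_shape_range->SetMax(const_cast<gert::Shape *>(input_shape_range->GetMax()));
  return GRAPH_SUCCESS;
};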