Passing Attribute Information Through TilingData
If an operator has attributes, their values can be passed to the kernel through TilingData so that the kernel function can use them in its computation. The ReduceMaxCustom operator is used as an example. It returns the maximum values of the input data along a specified dimension, together with their indices. ReduceMaxCustom has two attributes, reduceDim and isKeepDim: reduceDim specifies the dimension along which the reduction is performed, and isKeepDim specifies whether the output keeps the same number of dimensions as the input. In this sample, the reduction can be performed only on the last dimension, and the input data type is half.
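To make the operator semantics concrete, the following CPU reference is a minimal sketch, not the operator's actual kernel. It assumes a row-major input, uses float in place of half (standard C++ has no half type), always reduces the last axis as this sample does, and returns the first index when several elements share the maximum; the real operator's tie rule is not stated here. The names ReduceMaxResult and ReduceMaxLastDim are hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical CPU reference for ReduceMaxCustom semantics (float stands in for half).
struct ReduceMaxResult {
    std::vector<float> values;      // maximum of each row along the last axis
    std::vector<int32_t> indices;   // position of that maximum within the row
    std::vector<int64_t> outShape;  // output shape; depends on isKeepDim
};

ReduceMaxResult ReduceMaxLastDim(const std::vector<float>& x,
                                 const std::vector<int64_t>& shape,
                                 bool isKeepDim) {
    assert(!shape.empty() && shape.back() > 0);
    ReduceMaxResult r;
    const int64_t reduceAxisLen = shape.back();  // length of the reduced (last) axis
    const int64_t rows = static_cast<int64_t>(x.size()) / reduceAxisLen;
    for (int64_t row = 0; row < rows; ++row) {
        const float* p = x.data() + row * reduceAxisLen;
        int32_t best = 0;
        for (int64_t i = 1; i < reduceAxisLen; ++i) {
            // Strict comparison: the first occurrence wins on ties (an assumption of this sketch).
            if (p[i] > p[best]) best = static_cast<int32_t>(i);
        }
        r.values.push_back(p[best]);
        r.indices.push_back(best);
    }
    r.outShape.assign(shape.begin(), shape.end() - 1);  // drop the reduced axis
    if (isKeepDim) r.outShape.push_back(1);             // or keep it with length 1
    return r;
}
```

For an input of shape {2, 3}, isKeepDim = true yields an output of shape {2, 1}, while isKeepDim = false yields shape {2}; in both cases reduceAxisLen is 3, which is exactly the value the tiling function below forwards to the kernel.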
- The TilingData definition of the ReduceMaxCustom operator is shown below. The field of interest is reduceAxisLen, which holds the length of the reduceDim axis, that is, the length of the last dimension. This value is passed to the kernel through TilingData and used in the computation.
```cpp
#ifndef REDUCE_MAX_CUSTOM_TILING_H
#define REDUCE_MAX_CUSTOM_TILING_H
#include "register/tilingdata_base.h"
namespace optiling {
BEGIN_TILING_DATA_DEF(ReduceMaxTilingData)
  TILING_DATA_FIELD_DEF(uint32_t, reduceAxisLen); // Add the tiling field to specify the length of the reduceDim axis.
  // Definitions of other TilingData parameters.
  ...
END_TILING_DATA_DEF;
// Register the TilingData class with the corresponding ReduceMaxCustom operator.
REGISTER_TILING_DATA_CLASS(ReduceMaxCustom, ReduceMaxTilingData)
}
#endif // REDUCE_MAX_CUSTOM_TILING_H
```
- The tiling function of the ReduceMaxCustom operator is implemented as follows. To pass the attribute information through TilingData, it obtains the reduceDim attribute value from the operator attributes through TilingContext, uses that value to look up the length of the reduceDim axis in the input shape, and sets this length in TilingData.
```cpp
namespace optiling {
static ge::graphStatus TilingFunc(gert::TilingContext* context)
{
  ReduceMaxTilingData tiling;
  // Obtain the reduceDim attribute value from attr. Because reduceDim is the first attribute, the index value passed by GetAttrPointer is 0.
  const gert::RuntimeAttrs* attrs = context->GetAttrs();
  const uint32_t* reduceDim = attrs->GetAttrPointer<uint32_t>(0);
  // Obtain the length of the reduceDim axis.
  const gert::StorageShape* xShapePtr = context->GetInputShape(0);
  const gert::Shape& xShape = xShapePtr->GetStorageShape();
  const uint32_t reduceAxisLen = xShape.GetDim(*reduceDim);
  // Compute the values of member variables except reduceAxisLen in TilingData.
  ...
  // Set reduceAxisLen to the tiling structure and pass it to the kernel function.
  tiling.set_reduceAxisLen(reduceAxisLen);
  // Set the values of member variables except reduceAxisLen in TilingData.
  ...
  // Serialize and save TilingData.
  tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity());
  context->GetRawTilingData()->SetDataSize(tiling.GetDataSize());
  ...
  return ge::GRAPH_SUCCESS;
}
} // namespace optiling
```
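The host-to-kernel data flow above can be traced end to end on a CPU with the following sketch. It does not use CANN: TilingDataSketch, FakeTilingContext, FakeTilingFunc, and FakeKernelRowMax are hypothetical stand-ins for the generated ReduceMaxTilingData class, gert::TilingContext, TilingFunc, and the kernel. Attribute values are modeled as int64_t, and the check that reduceDim is the last dimension is an addition of this sketch that reflects the sample's restriction. On the kernel side, the raw bytes are copied back into the structure, a step that an Ascend C kernel typically performs with the GET_TILING_DATA macro.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical stand-in for the class generated by BEGIN_TILING_DATA_DEF /
// TILING_DATA_FIELD_DEF. Only reduceAxisLen is modeled; elided fields are omitted.
struct TilingDataSketch {
    uint32_t reduceAxisLen = 0;
    void set_reduceAxisLen(uint32_t v) { reduceAxisLen = v; }
    size_t GetDataSize() const { return sizeof(reduceAxisLen); }
    // Copies the fields into a raw byte buffer; returns false if they do not fit.
    bool SaveToBuffer(void* buf, size_t capacity) const {
        if (capacity < GetDataSize()) return false;
        std::memcpy(buf, &reduceAxisLen, sizeof(reduceAxisLen));
        return true;
    }
};

// Hypothetical, simplified stand-in for gert::TilingContext (not a CANN type).
struct FakeTilingContext {
    std::vector<int64_t> attrs;       // attrs[0] plays the role of reduceDim
    std::vector<int64_t> inputShape;  // storage shape of input x
    uint8_t rawTiling[64] = {};       // plays the role of GetRawTilingData()
    size_t rawTilingSize = 0;
};

enum class Status { kSuccess, kFailed };

// Same steps as TilingFunc: read the attribute, look up the axis length,
// set it in the tiling data, and serialize it into the raw buffer.
Status FakeTilingFunc(FakeTilingContext& ctx) {
    const int64_t rank = static_cast<int64_t>(ctx.inputShape.size());
    if (ctx.attrs.empty() || rank == 0) return Status::kFailed;
    const int64_t reduceDim = ctx.attrs[0];
    // This sample supports reduction over the last dimension only.
    if (reduceDim != rank - 1) return Status::kFailed;
    TilingDataSketch tiling;
    tiling.set_reduceAxisLen(static_cast<uint32_t>(ctx.inputShape[reduceDim]));
    if (!tiling.SaveToBuffer(ctx.rawTiling, sizeof(ctx.rawTiling))) return Status::kFailed;
    ctx.rawTilingSize = tiling.GetDataSize();
    return Status::kSuccess;
}

// Kernel-side consumer: rebuilds reduceAxisLen from the raw bytes and uses it
// as the row length when computing the per-row maximum.
std::vector<float> FakeKernelRowMax(const uint8_t* rawTiling, const std::vector<float>& x) {
    TilingDataSketch tiling;
    std::memcpy(&tiling.reduceAxisLen, rawTiling, sizeof(tiling.reduceAxisLen));
    const size_t len = tiling.reduceAxisLen;
    assert(len > 0 && x.size() % len == 0);
    std::vector<float> out;
    for (size_t row = 0; row < x.size() / len; ++row) {
        float m = x[row * len];
        for (size_t i = 1; i < len; ++i) {
            if (x[row * len + i] > m) m = x[row * len + i];
        }
        out.push_back(m);
    }
    return out;
}
```

The point of the design is that the kernel never sees the attribute itself: the host resolves reduceDim into a plain length once, and the kernel only reads that precomputed number from the tiling buffer.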