TensorTrait Overview
TensorTrait is a basic class template that describes a tensor's data type, logical position, and memory layout. It relies on template metaprogramming to perform computation and code generation at compile time, which reduces runtime overhead.
Header Files to Be Included
#include "kernel_operator_tensor_trait.h"
Prototype
template <typename T, TPosition pos = TPosition::GM, typename LayoutType = Layout<Shape<>, Stride<>>>
struct TensorTrait {
    using LiteType = T;
    using LiteLayoutType = LayoutType;
    static constexpr const TPosition tPos = pos; // This constant member is reserved for future function extension.
public:
    __aicore__ inline TensorTrait(const LayoutType& t = {});
    __aicore__ inline LayoutType& GetLayout();
    __aicore__ inline const LayoutType& GetLayout() const;
    __aicore__ inline void SetLayout(const LayoutType& t);
};
Template Parameters
| Parameter | Description |
|---|---|
| T | Only the following basic data types are supported: int4b_t, uint8_t, int8_t, int16_t, uint16_t, bfloat16_t, int32_t, uint32_t, int64_t, uint64_t, float, and half. TensorTrait uses the using keyword to define the type alias LiteType, which is identical to the template parameter T. A LocalTensor/GlobalTensor defined with TensorTrait does not contain ShapeInfo. For example, the ShapeInfo-free counterpart of LocalTensor<float> is LocalTensor<TensorTrait<float>>. |
| pos | Logical position where the data is stored. The value is of the TPosition type. The default value is TPosition::GM. |
| LayoutType | Layout data type. The default value is the empty layout Layout<Shape<>, Stride<>>. The LayoutType passed in must meet the constraints listed in Restrictions. |
Member Functions
__aicore__ inline TensorTrait(const LayoutType& t = {})
__aicore__ inline LayoutType& GetLayout()
__aicore__ inline const LayoutType& GetLayout() const
__aicore__ inline void SetLayout(const LayoutType& t)
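The behavior of these member functions can be sketched off-device with a simplified host-side model. Everything below is a stand-in written for illustration, not the contents of kernel_operator_tensor_trait.h: the TPosition, Shape, Stride, and Layout definitions are minimal placeholders, the tag field and the RoundTripLayoutTag helper are hypothetical, and the __aicore__ device qualifier is omitted so that the code builds with an ordinary C++ compiler.

```cpp
#include <cassert>
#include <type_traits>

// Host-side placeholders (assumptions). The real types come from the Ascend C headers.
enum class TPosition { GM, VECIN };
template <typename... Ts> struct Shape {};
template <typename... Ts> struct Stride {};
template <typename ShapeType, typename StrideType>
struct Layout {
    int tag = 0;  // hypothetical payload, used only to tell two layouts apart
};

// Simplified model that follows the documented prototype (device qualifier omitted).
template <typename T, TPosition pos = TPosition::GM,
          typename LayoutType = Layout<Shape<>, Stride<>>>
struct TensorTrait {
    using LiteType = T;
    using LiteLayoutType = LayoutType;
    static constexpr TPosition tPos = pos;

    TensorTrait(const LayoutType& t = {}) : layout_(t) {}
    LayoutType& GetLayout() { return layout_; }
    const LayoutType& GetLayout() const { return layout_; }
    void SetLayout(const LayoutType& t) { layout_ = t; }

private:
    LayoutType layout_;  // assumed storage; the real member layout is not documented here
};

using FloatTrait = TensorTrait<float>;

// Compile-time facts stated in the parameter table.
static_assert(std::is_same<FloatTrait::LiteType, float>::value, "LiteType is the same as T");
static_assert(FloatTrait::tPos == TPosition::GM, "pos defaults to TPosition::GM");

// Stores a layout with SetLayout and reads it back through the const GetLayout overload.
inline int RoundTripLayoutTag(int tag) {
    FloatTrait trait;  // default-constructed with the empty layout
    FloatTrait::LiteLayoutType layout;
    layout.tag = tag;
    trait.SetLayout(layout);
    const FloatTrait& view = trait;
    return view.GetLayout().tag;
}
```

In this model the non-const GetLayout overload returns a mutable reference, so a layout can also be modified in place without calling SetLayout.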
Related APIs
// TensorTrait structure construction method
template <typename T, TPosition pos, typename LayoutType>
__aicore__ inline constexpr auto MakeTensorTrait(const LayoutType& t)

// is_tensorTrait prototype definition
template <typename T> struct is_tensorTrait
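One plausible reading of these two prototypes, sketched as a host-side model: MakeTensorTrait forwards the layout to the TensorTrait constructor, and is_tensorTrait is a type trait specialized for TensorTrait instantiations. The function bodies, the value member on is_tensorTrait, the constexpr constructor, and all placeholder types (TPosition, Shape, Stride, Layout with its tag field) are assumptions made for illustration. Only the signatures come from this page, and the __aicore__ qualifier is again omitted.

```cpp
#include <cassert>
#include <type_traits>

// Host-side placeholders (assumptions), not the Ascend C definitions.
enum class TPosition { GM, VECIN };
template <typename... Ts> struct Shape {};
template <typename... Ts> struct Stride {};
template <typename ShapeType, typename StrideType>
struct Layout {
    int tag = 0;  // hypothetical payload
};

template <typename T, TPosition pos = TPosition::GM,
          typename LayoutType = Layout<Shape<>, Stride<>>>
struct TensorTrait {
    using LiteType = T;
    // constexpr is added in this model so the helper below can be constant-evaluated.
    constexpr TensorTrait(const LayoutType& t = {}) : layout_(t) {}
    constexpr const LayoutType& GetLayout() const { return layout_; }

private:
    LayoutType layout_;
};

// Assumed body: construct a TensorTrait from the given layout; LayoutType is deduced.
template <typename T, TPosition pos, typename LayoutType>
inline constexpr auto MakeTensorTrait(const LayoutType& t) {
    return TensorTrait<T, pos, LayoutType>(t);
}

// Assumed implementation: false by default, true for any TensorTrait instantiation.
template <typename T>
struct is_tensorTrait : std::false_type {};
template <typename T, TPosition pos, typename LayoutType>
struct is_tensorTrait<TensorTrait<T, pos, LayoutType>> : std::true_type {};

static_assert(is_tensorTrait<TensorTrait<float>>::value, "TensorTrait<float> is detected");
static_assert(!is_tensorTrait<float>::value, "a plain element type is not a TensorTrait");

// Builds a trait with the default position and layout type, then reads the layout back.
inline int MakeAndReadTag(int tag) {
    Layout<Shape<>, Stride<>> layout;
    layout.tag = tag;
    auto trait = MakeTensorTrait<float, TPosition::GM>(layout);
    static_assert(std::is_same<decltype(trait), TensorTrait<float>>::value,
                  "default pos and LayoutType are preserved");
    return trait.GetLayout().tag;
}
```

The sketch passes TPosition::GM and the empty layout type on purpose, so the resulting type equals TensorTrait<float> and stays within the restriction that pos and LayoutType keep their default values.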
Restrictions
- A single API call cannot mix TensorTrait-type and non-TensorTrait-type GlobalTensor/LocalTensor inputs.
- Copy construction and assignment between TensorTrait-type and non-TensorTrait-type GlobalTensor/LocalTensor objects are not supported.
- Currently, the TensorTrait feature supports only the APIs listed in Table 2.
- When used with APIs, the TensorTrait structure does not support the configuration of the pos and LayoutType template parameters. You need to use the constructor to construct the TensorTrait, and retain the default values of pos and LayoutType.
- The DataCopy API for data movement requires the ShapeInfo information and does not support the input of GlobalTensor or LocalTensor of the TensorTrait type.
Table 2 APIs supported by TensorTrait

| API Category | API |
|---|---|
| Basic API > Resource management > TQue/TQueBind | AllocTensor, FreeTensor, EnQue, DeQue |
| Basic API > Vector computation > Basic arithmetic | Exp, Ln, Abs, Reciprocal, Sqrt, Rsqrt, Relu, Add, Sub, Mul, Div, Max, Min, Adds, Muls, Maxs, Mins, VectorPadding, BilinearInterpolation, LeakyRelu |
| Basic API > Vector computation > Logical computation | And, Or |
| Basic API > Vector computation > Compound computation | CastDeq, AddRelu, AddDeqRelu, SubRelu, MulAddDst, FusedMulAdd, FusedMulAddRelu, AddReluCast, SubReluCast, MulCast |
| Basic API > Data movement | DataCopy, Copy |
| Basic API > Matrix computation | InitConstValue, LoadData, LoadDataWithTranspose, SetAippFunctions, LoadImageToLocal, LoadUnzipIndex, LoadDataUnzip, LoadDataWithSparse, Mmad, MmadWithSparse, BroadCastVecToMM, Gemm, Fixpipe |
| Basic API > Vector computation > Comparison and selection | Compare, GetCmpMask, SetCmpMask, Select, GatherMask |
| Basic API > Vector computation > Type conversion | Cast |
| Basic API > Vector computation > Reduction computation | ReduceMax, BlockReduceMax, WholeReduceMax, ReduceMin, BlockReduceMin, WholeReduceMin, ReduceSum, BlockReduceSum, WholeReduceSum, RepeatReduceSum, PairReduceSum |
| Basic API > Vector computation > Data conversion | Transpose, TransDataTo5HD |
| Basic API > Vector computation > Data filling | Brcb |
| Basic API > Vector computation > Discretization and aggregation | Gather, Gatherb, Scatter |
| Basic API > Vector computation > Sorting and combination (ISASI) | ProposalConcat, ProposalExtract, RpSort16, MrgSort4, Sort32 |