简介
TensorTrait数据结构是描述Tensor相关信息的基础模板类,包含Tensor的数据类型、逻辑位置和Layout内存布局。借助模板元编程技术,该类在编译时完成计算和代码生成,从而降低运行时开销。
需要包含的头文件
1 | #include "kernel_operator_tensor_trait.h" |
原型定义
template <typename T, TPosition pos = TPosition::GM, typename LayoutType = Layout<Shape<>, Stride<>>> struct TensorTrait { using LiteType = T; using LiteLayoutType = LayoutType; static constexpr const TPosition tPos = pos; // 该常量成员为后续功能拓展做预留 public: __aicore__ inline TensorTrait(const LayoutType& t = {}); __aicore__ inline LayoutType& GetLayout(); __aicore__ inline const LayoutType& GetLayout() const; __aicore__ inline void SetLayout(const LayoutType& t); };
模板参数
参数名 |
描述 |
---|---|
T |
只支持如下基础数据类型:int4b_t、uint8_t、int8_t、int16_t、uint16_t、bfloat16_t、int32_t、uint32_t、int64_t、uint64_t、float、half 。 在TensorTrait结构体内部,使用using关键字定义了一个类型别名LiteType,与模板参数T类型一致。 通过TensorTrait定义的LocalTensor/GlobalTensor不包含ShapeInfo信息。 例如:LocalTensor<float>对应的不含ShapeInfo信息的Tensor为LocalTensor<TensorTrait<float>>。 |
pos |
数据存放的逻辑位置,Tposition类型,默认为TPosition::GM。 |
LayoutType |
Layout数据类型,默认为空类型,即Layout<Shape<>, Stride<>>。 输入的数据类型LayoutType,需满足约束说明。 |
成员函数
__aicore__ inline TensorTrait(const LayoutType& t = {}) __aicore__ inline LayoutType& GetLayout() __aicore__ inline const LayoutType& GetLayout() const __aicore__ inline void SetLayout(const LayoutType& t)
相关接口
// TensorTrait结构构造方法 template <typename T, TPosition pos, typename LayoutType> __aicore__ inline constexpr auto MakeTensorTrait(const LayoutType& t) // is_tensorTrait原型定义 template <typename T> struct is_tensorTrait
约束说明
- 同一接口不支持同时输入TensorTrait类型的GlobalTensor/LocalTensor和非TensorTrait类型的GlobalTensor/LocalTensor。
- 非TensorTrait类型和TensorTrait类型的GlobalTensor/LocalTensor相互之间不支持拷贝构造和赋值运算符。
- TensorTrait特性当前仅支持如下接口:
- 当前暂不支持TensorTrait结构配置pos、LayoutType模板参数,和API配合使用时,需要使用构造函数构造TensorTrait,pos、LayoutType保持默认值即可。
- DataCopy切片数据搬运接口需要ShapeInfo信息,不支持输入TensorTrait类型的GlobalTensor/LocalTensor。
表2 TensorTrait特性支持的接口列表 接口分类
接口名称
基础API>内存管理与同步控制>TQue/TQueBind
AllocTensor、FreeTensor、EnQue、DeQue
基础API>矢量计算>基础算术
Exp、Ln、Abs、Reciprocal、Sqrt、Rsqrt、Not、Relu、Add、Sub、Mul、Div、Max、Min、Adds、Muls、Maxs、Mins、LeakyRelu
基础API>矢量计算>逻辑计算
And、Or、ShiftLeft、ShiftRight
基础API>矢量计算>复合计算
AddRelu、AddReluCast、AddDeqRelu、SubRelu、SubReluCast、MulAddDst、FusedMulAdd、FusedMulAddRelu
基础API>数据搬运
DataCopy、Copy
基础指令>ISASI(体系结构相关)>矩阵计算
InitConstValue、LoadData、LoadDataWithTranspose、SetAippFunctions、LoadImageToLocal、LoadUnzipIndex、LoadDataUnzip、LoadDataWithSparse、SetFmatrix、SetLoadDataBoundary、SetLoadDataRepeat、SetLoadDataPaddingValue、Mmad、MmadWithSparse、Fixpipe、SetFixPipeConfig、SetFixpipeNz2ndFlag、SetFixpipePreQuantFlag、BroadCastVecToMM、SetHF32Mode、SetHF32TransMode、SetMMLayoutTransform、CheckLocalMemoryIA、Gemm