TensorTrait

功能说明

GlobalTensorLocalTensor通过ShapeInfo类型的成员变量来保存shape信息,并可以通过SetShapeInfo和GetShapeInfo方法进行设置或获取。这些方法通常用于算子实现内部的shape信息保存和传递。如果不需要使用ShapeInfo功能,可以通过TensorTrait定义不含ShapeInfo的GlobalTensor和LocalTensor,从而降低内存占用并提升运行性能。

定义原型

1
2
3
4
template <typename T>
struct TensorTrait {
    using LiteType = T;
};

参数说明

表1 TensorTrait结构体模板参数说明

参数名

描述

T

只支持如下基础数据类型:int4b_t、uint8_t、int8_t、int16_t、uint16_t、bfloat16_t、int32_t、uint32_t、int64_t、uint64_t、float、half 。

通过TensorTrait可以得到一个使用TensorTrait表达的Tensor数据类型:在TensorTrait结构体内部,使用using关键字定义了一个类型别名LiteType,与模板参数T类型一致

通过TensorTrait定义的LocalTensor/GlobalTensor不包含ShapeInfo信息。

例如:

LocalTensor<float>对应的不含ShapeInfo信息的Tensor为LocalTensor<TensorTrait<float>>。

约束说明

调用示例