TensorTrait

功能说明

GlobalTensorLocalTensor中通过ShapeInfo类型的成员变量来保存shape信息,可以通过SetShapeInfo、GetShapeInfo来进行设置或者获取,通常用于算子实现内部的shape信息保存和传递。在不使用上述ShapeInfo功能的情况下,不需要这些信息。此时可以使用TensorTrait定义不含ShapeInfo的GlobalTensor以及LocalTensor,以降低内存占用,提升运行性能。

定义原型

1
2
3
4
template <typename T>
struct TensorTrait {
    using LiteType = T;
};

参数说明

表1 TensorTrait结构体模板参数说明

参数名

描述

T

只支持如下基础数据类型:int4b_t、uint8_t、int8_t、int16_t、uint16_t、bfloat16_t、int32_t、uint32_t、int64_t、uint64_t、float、half 。

通过TensorTrait可以得到一个使用TensorTrait表达的Tensor数据类型:在TensorTrait结构体内部,使用using关键字定义了一个类型别名LiteType,与模板参数T类型一致。

通过TensorTrait定义的LocalTensor/GlobalTensor不包含ShapeInfo信息。

例如:

LocalTensor<float>对应的不含ShapeInfo信息的Tensor为LocalTensor<TensorTrait<float>>。

约束说明

调用示例