TensorDesc
功能说明
TensorDesc用于储存ListTensor.GetDesc()中根据index获取对应的Tensor描述信息。
定义原型
1 2 3 4 5 6 7 8 9 10 |
template<class T> class TensorDesc { TensorDesc(); ~TensorDesc(); void SetShapeAddr(uint64_t* shapePtr); uint64_t GetDim(); uint64_t GetIndex(); uint64_t GetShape(uint32_t offset); T* GetDataPtr(); GlobalTensor<T> GetDataObj(); } |
函数说明
函数名称 |
入参说明 |
含义 |
---|---|---|
SetShapeAddr |
shapePtr: 用于储存shape信息的地址 |
配置用于储存shape信息的地址。 |
GetDim |
- |
获取Tensor的维度。 |
GetIndex |
- |
获取TensorDesc在ListTensorDesc中对应的索引值。 |
GetShape |
offset: 索引值 |
获取对应维度的shape信息。 |
GetDataPtr |
- |
获取储存Tensor数据地址。 |
GetDataObj |
- |
将数据指针置于GlobalTensor中并返回该GlobalTensor。 |
支持的型号
Atlas推理系列产品AI Core
Atlas A2训练系列产品/Atlas 800I A2推理产品
注意事项
使用GetDesc前需要先调用TensorDesc.SetShapeAddr为desc指定用于储存shape信息的地址,调用GetDesc后会将shape信息写入该地址。
调用示例
示例中待解析的src_gm内存排布如下图所示:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 |
AscendC::ListTensorDesc listTensorDesc(reinterpret_cast<__gm__ void *>(srcGm)); // src_gm为待解析的gm地址 uint32_t size = listTensorDesc.GetSize(); // size = 2 auto dataPtr0 = listTensorDesc.GetDataPtr<int32_t>(0); // 获取ptr0 auto dataPtr1 = listTensorDesc.GetDataPtr<int32_t>(1); // 获取ptr1 uint64_t buf[100] = {0}; // 示例中Tensor的dim为3, 此处的100表示预留足够大的空间 AscendC::TensorDesc<int32_t> desc; desc.SetShapeAddr(buf); // 为desc指定用于储存shape信息的地址 listTensorDesc.GetDesc(desc, 0); // 获取索引0的shape信息 uint64_t dim = desc.GetDim(); // dim = 3 uint64_t idx = desc.GetIndex(); // idx = 0 uint64_t shape[3] = {0}; for (uint32_t i = 0; i < desc.GetDim(); i++) { shape[i] = desc.GetShape(i); // GetShape(0) = 1, GetShape(1) = 2, GetShape(2) = 3 } auto ptr = desc.GetDataPtr(); |
父主题: 数据类型定义