TensorDesc

功能说明

TensorDesc用于储存ListTensor.GetDesc()中根据index获取对应的Tensor描述信息。

定义原型

template<class T> class TensorDesc {
    TensorDesc();
    ~TensorDesc();
    void SetShapeAddr(uint64_t* shapePtr);
    uint64_t GetDim();
    uint64_t GetIndex();
    uint64_t GetShape(uint32_t offset);
    T* GetDataPtr();
    GlobalTensor<T> GetDataObj();
}

函数说明

函数名称

入参说明

含义

SetShapeAddr

shapePtr: 用于储存shape信息的地址

配置用于储存shape信息的地址。

GetDim

-

获取Tensor的维度。

GetIndex

-

获取TensorDesc在ListTensorDesc中对应的索引值。

GetShape

offset: 索引值

获取对应维度的shape信息。

GetDataPtr

-

获取储存Tensor数据地址。

GetDataObj

-

将数据指针置于GlobalTensor中并返回该GlobalTensor。

支持的型号

Atlas推理系列产品AI Core

Atlas A2训练系列产品/Atlas 800I A2推理产品

注意事项

使用GetDesc前需要先调用TensorDesc.SetShapeAddr为desc指定用于储存shape信息的地址,调用GetDesc后会将shape信息写入该地址。

调用示例

示例中待解析的src_gm内存排布如下图所示:

ListTensorDesc listTensorDesc(reinterpret_cast<__gm__ void *>(src_gm));  // src_gm为待解析的gm地址
uint32_t size = listTensorDesc.GetSize();  // size = 2
auto dataPtr0 = listTensorDesc.GetDataPtr<int32_t>(0);  // 获取ptr0
auto dataPtr1 = listTensorDesc.GetDataPtr<int32_t>(1);  // 获取ptr1

__gm__ uint64_t buf[100] = {0};  // 示例中Tensor的dim为3, 此处的100表示预留足够大的空间
TensorDesc<int32_t> desc;                                                                                                          
desc.SetShapeAddr(&buf[0]);  // 为desc指定用于储存shape信息的地址
listTensorDesc.GetDesc(desc, 0);  // 获取索引0的shape信息

uint64_t dim = desc.GetDim();  // dim = 3
uint64_t idx = desc.GetIndex();  // idx = 0
uint64_t shape[3] = {0};
for(uint32_t i = 0; i < desc.GetDim(); i++) {
    shape[i] = desc.GetShape(i);  // GetShape(0) = 1, GetShape(1) = 2, GetShape(2) = 3
}
auto ptr = desc.GetDataPtr();