ListTensorDesc
功能说明
ListTensorDesc用来解析符合以下内存排布格式的数据, 并在kernel侧根据索引获取储存对应数据的地址及shape信息:
定义原型
class ListTensorDesc { ListTensorDesc(); ListTensorDesc(__gm__ void* data, uint32_t length = 0xffffffff, uint32_t shapeSize = 0xffffffff); void Init(__gm__ void* data, uint32_t length = 0xffffffff, uint32_t shapeSize = 0xffffffff); template<class T> void GetDesc(TensorDesc<T>& desc, uint32_t index); template<class T> T* GetDataPtr(uint32_t index); uint32_t GetSize(); }
函数说明
函数名称 |
入参说明 |
含义 |
---|---|---|
ListTensorDesc |
- |
默认构造函数,需配合Init函数使用。 |
ListTensorDesc |
data:待解析数据的首地址 length:待解析内存的长度 shapeSize:数据指针的个数 length和shapeSize仅用于校验,不填写时不进行校验 |
ListTensorDesc类的构造函数, 用于解析对应的内存排布。 |
Init |
data:待解析数据的首地址 length:待解析内存的长度 shapeSize:数据指针的个数 length和shapeSize仅用于校验,不填写时不进行校验 |
初始化函数, 用于解析对应的内存排布。 |
GetDesc |
desc:出参, 解析后的Tensor描述信息 index:索引值 |
根据index获得对应的Tensor描述信息。 使用GetDesc前需要先调用TensorDesc.SetShapeAddr为desc指定用于储存shape信息的地址,调用GetDesc后会将shape信息写入该地址。 Atlas推理系列产品(Ascend 310P处理器)AI Core支持该功能 Atlas 训练系列产品不支持该功能 Atlas A2训练系列产品/Atlas 800I A2推理产品支持该功能 Atlas 200/500 A2推理产品不支持该功能 |
GetDataPtr |
index:索引值 |
根据index获取储存对应数据的地址。 |
GetSize |
- |
获取ListTensor中包含的数据指针的个数。 |
支持的型号
Atlas推理系列产品(Ascend 310P处理器)AI Core
Atlas A2训练系列产品/Atlas 800I A2推理产品
调用示例
示例中待解析的src_gm内存排布如下图所示:
AscendC::ListTensorDesc listTensorDesc(reinterpret_cast<__gm__ void *>(srcGm)); // src_gm为待解析的gm地址 uint32_t size = listTensorDesc.GetSize(); // size = 2 auto dataPtr0 = listTensorDesc.GetDataPtr<int32_t>(0); // 获取ptr0 auto dataPtr1 = listTensorDesc.GetDataPtr<int32_t>(1); // 获取ptr1 uint64_t buf[100] = {0}; // 示例中Tensor的dim为3, 此处的100表示预留足够大的空间 AscendC::TensorDesc<int32_t> desc; desc.SetShapeAddr(buf); // 为desc指定用于储存shape信息的地址 listTensorDesc.GetDesc(desc, 0); // 获取索引0的shape信息 uint64_t dim = desc.GetDim(); // dim = 3 uint64_t idx = desc.GetIndex(); // idx = 0 uint64_t shape[3] = {0}; for (uint32_t i = 0; i < desc.GetDim(); i++) { shape[i] = desc.GetShape(i); // GetShape(0) = 1, GetShape(1) = 2, GetShape(2) = 3 } auto ptr = desc.GetDataPtr();