昇腾社区首页
中文
注册

模型的输入和输出数据结构准备

关于准备模型输入、输出数据结构的接口调用流程,请依次参见主要接口调用流程基本的模型执行流程准备模型推理的输入/输出数据

基本原理

pyACL提供了以下数据类型来描述模型、模型输入、模型输出以及存放数据的内存,在模型执行前,需要构造好这些数据类型,作为模型执行的输入:

  • 使用aclmdlDesc类型的数据描述模型基本信息(例如输入/输出的个数、名称、数据类型、Format、维度信息等)。

    模型加载成功后,用户可根据模型的ID,调用acl.mdl.get_desc接口获取该模型的描述信息,进而从模型的描述信息中获取模型输入/输出的个数、内存大小、维度信息、Format、数据类型等信息,可参见aclmdlDesc类型下的操作接口。

  • 使用aclmdlDataset类型的数据描述模型的输入/输出数据,模型可能存在多个输入、多个输出。

    调用aclmdlDataset类型下的操作接口添加aclDataBuffer类型的数据、获取aclDataBuffer的个数等。

  • 每个输入/输出的内存地址、内存大小用aclDataBuffer类型的数据来描述。

    调用aclDataBuffer类型下的操作接口获取内存地址、内存大小等。

示例代码

您可以从样例介绍中获取完整样例代码。

调用接口后,需增加异常处理的分支,并记录报错日志、提示日志,此处不一一列举。以下是关键步骤的代码示例,不可以直接拷贝运行,仅供参考。

# 初始化变量
ACL_MEM_MALLOC_HUGE_FIRST = 0

# 1.根据模型的ID,获取该模型的描述信息
# self.model_desc为aclmdlDesc类型
self.model_desc = acl.mdl.create_desc()
ret = acl.mdl.get_desc(self.model_desc, self.model_id)

# 2.准备模型推理的输入数据集
# 创建aclmdlDataset类型的数据,描述模型推理的输入
self.load_input_dataset = acl.mdl.create_dataset()
# 获取模型输入的数量
input_size = acl.mdl.get_num_inputs(self.model_desc)
self.input_data = []
# 循环为每个输入申请内存,并将每个输入添加到aclmdlDataset类型的数据中
for i in range(input_size):
    buffer_size = acl.mdl.get_input_size_by_index(self.model_desc, i)
    # 申请输入内存
    buffer, ret = acl.rt.malloc(buffer_size, ACL_MEM_MALLOC_HUGE_FIRST)
    data = acl.create_data_buffer(buffer, buffer_size)
    _, ret = acl.mdl.add_dataset_buffer(self.load_input_dataset, data)
    self.input_data.append({"buffer": buffer, "size": buffer_size})

# 3.准备模型推理的输出数据集
# 创建aclmdlDataset类型的数据,描述模型推理的输出
self.load_output_dataset = acl.mdl.create_dataset()
# 获取模型输出的数量
output_size = acl.mdl.get_num_inputs(self.model_desc)
self.output_data = []
# 循环为每个输出申请内存,并将每个输出添加到aclmdlDataset类型的数据中
for i in range(output_size):
    buffer_size = acl.mdl.get_input_size_by_index(self.model_desc, i)
    # 申请输出内存
    buffer, ret = acl.rt.malloc(buffer_size, ACL_MEM_MALLOC_HUGE_FIRST)
    data = acl.create_data_buffer(buffer, buffer_size)
    _, ret = acl.mdl.add_dataset_buffer(self.load_output_dataset, data)
    self.output_data.append({"buffer": buffer, "size": buffer_size})

# ......