
Forward Inference

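  • Forward inference requires an additional maximum-iteration-count parameter, "max_iters". Because each model computes "max_iters" differently, the infer function must either be overridden in the subclass, or the parameter must be computed in advance. If you override infer in the subclass, make sure its parameter signature exactly matches that of the base class's infer; after computing the parameter, call super().infer() to run the parent class's infer logic.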

    Example:

    # Requires `import math` at module level; ENV is assumed to be the framework's environment settings object.
    def infer(self, mm_inputs, batch_size, max_output_length, ignore_eos, max_iters=None, **kwargs):
        input_texts = mm_inputs.input_texts
        image_path_list = mm_inputs.image_path
        video_path_list = mm_inputs.video_path
        # Use the image paths if present; otherwise fall back to the video paths.
        path_list = image_path_list if image_path_list else video_path_list
        if len(input_texts) != len(path_list):
            raise RuntimeError("input_texts length must equal the number of image or video paths")
        if not ENV.profiling_enable:
            if self.max_batch_size > 0:
                # One iteration per full or partial batch.
                max_iters = math.ceil(len(path_list) / self.max_batch_size)
            else:
                raise RuntimeError(f"max_batch_size should be > 0, got {self.max_batch_size}, please check")
        return super().infer(mm_inputs, batch_size, max_output_length, ignore_eos, max_iters=max_iters)
    
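  • If the model takes custom parameters beyond the four input modalities (text, image, audio, or video), override the "prepare_request" function. For reference, see the implementation of "prepare_request" in {llm_path}/examples/models/multimodal_runner.py.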
  • To save precision test results, override the precision_save method in the subclass. Example:
    def precision_save(self, precision_inputs, **kwargs):
        all_input_texts = precision_inputs.all_input_texts
        all_generate_text_list = precision_inputs.all_generate_text_list
        image_file_list = precision_inputs.image_file_list
        video_file_list = precision_inputs.video_file_list
        # Use the image files if present; otherwise fall back to the video files.
        file_list = image_file_list if image_file_list else video_file_list
        answer_pairs = {}
        if not file_list:
            raise ValueError("Both image_file_list and video_file_list are empty.")
        if len(all_input_texts) != len(file_list):
            raise ValueError(
                f"Mismatched lengths between "
                f"all_input_texts={all_input_texts} and file_list={file_list}")
        # Map each input file to the text generated for it.
        for text_index in range(len(all_input_texts)):
            answer_pairs[file_list[text_index]] = all_generate_text_list[text_index]
        # Sort by file name once, after all pairs have been collected.
        answer_pairs = dict(sorted(answer_pairs.items()))
        super().precision_save(precision_inputs, answer_pairs=answer_pairs)
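
The override pattern for infer described above can be tried in isolation with a minimal sketch. The classes BaseRunner, MyModelRunner, and MultimodalInput below are hypothetical stand-ins, not the framework's real classes: the stub base infer only echoes the "max_iters" value it receives, and the profiling check from the example above is left out for brevity. The sketch only illustrates keeping the base signature, computing "max_iters" as the number of batches, and delegating to super().infer().

```python
import math
from dataclasses import dataclass, field
from typing import List


@dataclass
class MultimodalInput:
    """Hypothetical stand-in for the real multimodal input container."""
    input_texts: List[str]
    image_path: List[str] = field(default_factory=list)
    video_path: List[str] = field(default_factory=list)


class BaseRunner:
    """Hypothetical stand-in for the base class; the real infer() runs the model."""

    def __init__(self, max_batch_size: int):
        self.max_batch_size = max_batch_size

    def infer(self, mm_inputs, batch_size, max_output_length, ignore_eos,
              max_iters=None, **kwargs):
        # The stub only reports the iteration count it was given.
        return {"max_iters": max_iters}


class MyModelRunner(BaseRunner):
    # Same parameter list as the base class, as the text requires.
    def infer(self, mm_inputs, batch_size, max_output_length, ignore_eos,
              max_iters=None, **kwargs):
        # Prefer image paths; fall back to video paths when no images are given.
        paths = mm_inputs.image_path or mm_inputs.video_path
        if len(mm_inputs.input_texts) != len(paths):
            raise RuntimeError("input_texts and media paths must have equal length")
        if self.max_batch_size <= 0:
            raise RuntimeError(f"max_batch_size must be > 0, got {self.max_batch_size}")
        # One iteration per full or partial batch.
        max_iters = math.ceil(len(paths) / self.max_batch_size)
        return super().infer(mm_inputs, batch_size, max_output_length, ignore_eos,
                             max_iters=max_iters, **kwargs)


runner = MyModelRunner(max_batch_size=4)
inputs = MultimodalInput(
    input_texts=["describe the picture"] * 10,
    image_path=[f"img_{i}.jpg" for i in range(10)],
)
result = runner.infer(inputs, batch_size=4, max_output_length=128, ignore_eos=False)
print(result)  # {'max_iters': 3}
```

With 10 inputs and a maximum batch size of 4, the subclass passes max_iters=3 to the parent, since the last batch is only partially filled.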