def infer(self, mm_inputs, batch_size, max_output_length, ignore_eos, max_iters=None, **kwargs):
    """Run multimodal inference, batching inputs by ``self.max_batch_size``.

    Args:
        mm_inputs: Multimodal input bundle exposing ``input_texts``,
            ``image_path`` and ``video_path`` (project type).
        batch_size: Per-step batch size forwarded to the parent ``infer``.
        max_output_length: Maximum tokens to generate per sample.
        ignore_eos: Whether generation ignores the EOS token.
        max_iters: Optional iteration cap; recomputed from the input count
            unless profiling is enabled.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        Whatever the parent class's ``infer`` returns.

    Raises:
        RuntimeError: If text/path counts mismatch, or ``max_batch_size`` <= 0.
    """
    input_texts = mm_inputs.input_texts
    image_path_list = mm_inputs.image_path
    video_path_list = mm_inputs.video_path
    # Inputs are either images or videos; fall back to videos when no images.
    path_list = image_path_list if image_path_list else video_path_list
    # Fix: validate against the effective path list, not image_path_list only —
    # otherwise video-only inputs were mis-validated.
    if len(input_texts) != len(path_list):
        raise RuntimeError("input_text length must equal input_images length")
    if not ENV.profiling_enable:
        if self.max_batch_size > 0:
            # Fix: derive iteration count from the effective path list so that
            # video inputs do not yield max_iters == 0.
            max_iters = math.ceil(len(path_list) / self.max_batch_size)
        else:
            # Fix: 'f' prefix was inside the string literal, so the value was
            # never interpolated into the error message.
            raise RuntimeError(f"{self.max_batch_size} max_batch_size should > 0, please check")
    return super().infer(mm_inputs, batch_size, max_output_length, ignore_eos, max_iters=max_iters)
def precision_save(self, precision_inputs, **kwargs):
    """Persist precision-test results as a {file_path: generated_text} mapping.

    Builds an answer map from input file paths (images, or videos when no
    images are present) to their generated texts, sorts it by file path, and
    delegates the actual saving to the parent class.

    Args:
        precision_inputs: Bundle exposing ``all_input_texts``,
            ``all_generate_text_list``, ``image_file_list`` and
            ``video_file_list`` (project type).
        **kwargs: Accepted for interface compatibility; not used here.

    Raises:
        ValueError: If both file lists are empty, or input-text and file
            counts mismatch.
    """
    all_input_texts = precision_inputs.all_input_texts
    all_generate_text_list = precision_inputs.all_generate_text_list
    image_file_list = precision_inputs.image_file_list
    video_file_list = precision_inputs.video_file_list
    file_list = image_file_list if image_file_list else video_file_list
    if not file_list:
        raise ValueError("Both image_file_list and video_file_list are empty.")
    if len(all_input_texts) != len(file_list):
        # Fix: the backslash continuation inside the f-string embedded a run of
        # source-indentation spaces into the message; single clean message now.
        raise ValueError(
            f"Mismatched lengths between "
            f"all_input_texts={all_input_texts} and file_list={file_list}")
    # Fix: the loop previously assigned into the undefined name
    # 'image_answer_pairs' (NameError); populate 'answer_pairs' directly.
    answer_pairs = {}
    for text_index in range(len(all_input_texts)):
        answer_pairs[file_list[text_index]] = all_generate_text_list[text_index]
    # Fix: the sorted dict was bound to a dead variable while the unsorted one
    # was saved; rebind so the parent actually receives the sorted mapping.
    answer_pairs = dict(sorted(answer_pairs.items()))
    super().precision_save(precision_inputs, answer_pairs=answer_pairs)