通过继承或重写初始化、warm_up以及前向推理等方法完成对基类“MultimodalPARunner”的适配后,还需要在main函数中实现基本的路径解析和输入准备。以InternVL2.5为例:
def pad_to_batch(media_paths, texts, batch_size):
    """Pad media_paths and texts in place, repeating the last element of each,
    so that len(media_paths) becomes a multiple of batch_size.

    No-op when the length is already aligned. Assumes both lists are non-empty
    (callers validate this before the runner is constructed).
    """
    remainder = len(media_paths) % batch_size
    if remainder != 0:
        num_to_add = batch_size - remainder
        media_paths.extend([media_paths[-1]] * num_to_add)
        texts.extend([texts[-1]] * num_to_add)


if __name__ == '__main__':
    args = parse_arguments()
    rank = ENV.rank
    local_rank = ENV.local_rank
    world_size = ENV.world_size

    # Normalize and safety-check the user-supplied media directory, then list it.
    image_or_video_path = standardize_path(args.image_or_video_path)
    check_file_safety(image_or_video_path, 'r')
    file_name = safe_listdir(image_or_video_path)
    file_length = len(file_name)

    # Runner construction kwargs: rank/parallel info plus every parsed CLI arg.
    input_dict = {
        'rank': rank,
        'world_size': world_size,
        'local_rank': local_rank,
        'perf_file': PERF_FILE,
        **vars(args)
    }

    # Decide whether the directory holds images or videos and pick the
    # matching prompt list from the parsed arguments.
    if is_image_path(image_or_video_path):
        image_path = [os.path.join(image_or_video_path, f) for f in file_name]
        video_path = None
        input_dict['image_path'] = image_path
        texts = args.input_texts_for_image
    elif is_video_path(image_or_video_path):
        video_path = [os.path.join(image_or_video_path, f) for f in file_name]
        image_path = None
        input_dict['video_path'] = video_path
        texts = args.input_texts_for_video
    else:
        logger.error("Unsupported media type, it should be a video or image, please check your input.",
                     ErrorCode.ATB_MODELS_PARAM_OUT_OF_RANGE)
        raise KeyError("Unsupported media type, it should be a video or image, please check your input.")

    # Guard against an empty prompt list: the padding below reads texts[-1],
    # which would otherwise raise a bare IndexError.
    if not texts:
        raise ValueError("The input texts must not be empty.")
    if len(texts) > file_length:
        raise ValueError("The number of input texts is greater than the number of files.")
    # Reuse the last prompt so every media file has a matching prompt.
    texts.extend([texts[-1]] * (file_length - len(texts)))
    input_dict['input_texts'] = texts

    pa_runner = InternvlRunner(**input_dict)

    # Pad inputs so the total count divides evenly into max_batch_size batches.
    if image_path:
        pad_to_batch(image_path, texts, args.max_batch_size)
    elif video_path:
        pad_to_batch(video_path, texts, args.max_batch_size)

    print_log(rank, logger.info, f'pa_runner: {pa_runner}')
    infer_params = {
        "mm_inputs": MultimodalInput(texts, image_path, video_path, None),
        "batch_size": args.max_batch_size,
        "max_output_length": args.max_output_length,
        "ignore_eos": args.ignore_eos,
    }
    pa_runner.warm_up()
    generate_texts, token_nums, latency = pa_runner.infer(**infer_params)

    for i, generate_text in enumerate(generate_texts):
        print_log(rank, logger.info, f'Answer[{i}]: {generate_text}')
        print_log(rank, logger.info, f'Generate[{i}] token num: {token_nums[i]}')
    print_log(rank, logger.info, f"Latency: {latency}")