新增命令行参数 `--early_stop_iteration`（对应 args.early_stop_iteration，int 类型，默认值 -1），并将其传入训练主函数 `train()`。
修改前:
def train(cfg, local_rank, distributed, use_tensorboard=False,): …… do_train( cfg, model, data_loader, optimizer, scheduler, checkpointer, device, checkpoint_period, arguments, data_loaders_val, meters ) return model def main(): parser = argparse.ArgumentParser(description="PyTorch Object Detection Training") …… parser.add_argument("--override_output_dir", default=None) args = parser.parse_args() …… model = train(cfg=cfg, local_rank=args.local_rank, distributed=args.distributed, use_tensorboard=args.use_tensorboard)  # Before: train() has no early_stop_iteration parameter, main() defines no --early_stop_iteration flag, and do_train() is called without it.
修改后:
def train(cfg, local_rank, distributed, use_tensorboard=False, early_stop_iteration=-1): …… do_train( cfg, model, data_loader, optimizer, scheduler, checkpointer, device, checkpoint_period, arguments, data_loaders_val, meters, early_stop_iteration=early_stop_iteration, ) return model def main(): parser = argparse.ArgumentParser(description="PyTorch Object Detection Training") …… parser.add_argument("--override_output_dir", default=None) parser.add_argument("--early_stop_iteration", type=int, default=-1) args = parser.parse_args() …… model = train(cfg=cfg, local_rank=args.local_rank, distributed=args.distributed, use_tensorboard=args.use_tensorboard, early_stop_iteration=args.early_stop_iteration)  # After: main() parses --early_stop_iteration (int, default -1), passes it to train(), which forwards it to do_train() as a keyword argument; the default -1 (any value <= 0) leaves the early-stop break in do_train() disabled.
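在 `do_train()`（位于 `maskrcnn_benchmark/engine/trainer.py`）中接收 `early_stop_iteration` 参数：当其大于 0 时，在进入第 `early_stop_iteration + 1` 次迭代时退出训练循环，并在第 `early_stop_iteration` 次迭代额外触发一次评估。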
修改前:
def do_train( cfg, model, data_loader, optimizer, scheduler, checkpointer, device, checkpoint_period, arguments, val_data_loader=None, meters=None, zero_shot=False ): …… for iteration, (images, targets, idxs, positive_map, positive_map_eval, greenlight_map) in enumerate(data_loader, start_iter): …… arguments["iteration"] = iteration images = images.to(device) …… if val_data_loader and (iteration % checkpoint_period == 0 or iteration == max_iter): if is_main_process(): print("Evaluating") ……  # Before: no early-stop check in the loop; evaluation is triggered only when iteration % checkpoint_period == 0 or iteration == max_iter (and val_data_loader is set).
修改后:
def do_train( cfg, model, data_loader, optimizer, scheduler, checkpointer, device, checkpoint_period, arguments, val_data_loader=None, meters=None, zero_shot=False, early_stop_iteration=-1, ): …… for iteration, (images, targets, idxs, positive_map, positive_map_eval, greenlight_map) in enumerate(data_loader, start_iter): …… arguments["iteration"] = iteration if early_stop_iteration > 0: if iteration == early_stop_iteration + 1: break images = images.to(device) …… if val_data_loader and (iteration % checkpoint_period == 0 or iteration == max_iter or iteration == early_stop_iteration): if is_main_process(): print("Evaluating") ……  # After: when early_stop_iteration > 0 the loop breaks at the top of iteration early_stop_iteration + 1, and evaluation is additionally forced at iteration == early_stop_iteration (if val_data_loader is set). NOTE(review): by the time of the break, one extra batch has already been pulled from data_loader and arguments["iteration"] already holds early_stop_iteration + 1, an iteration that was never trained; consider moving the check above the arguments["iteration"] assignment (or breaking at the end of iteration early_stop_iteration) — confirm against whatever reads arguments after the loop.
修改“maskrcnn_benchmark/engine/trainer.py”脚本。
增加训练时实时计算 FPS（每秒处理的图像数）的逻辑：FPS = cfg.SOLVER.IMS_PER_BATCH / batch_time，并通过 `meters.update` 一并记录。
修改前:
def do_train( cfg, model, data_loader, optimizer, scheduler, checkpointer, device, checkpoint_period, arguments, val_data_loader=None, meters=None, zero_shot=False, early_stop_iteration=-1, ): …… for iteration, (images, targets, idxs, positive_map, positive_map_eval, greenlight_map) in enumerate(data_loader, start_iter): …… meters.update(time=batch_time, data=data_time)  # Before: this meters.update call records only the iteration time and data-loading time.
修改后:
def do_train( cfg, model, data_loader, optimizer, scheduler, checkpointer, device, checkpoint_period, arguments, val_data_loader=None, meters=None, zero_shot=False, early_stop_iteration=-1, ): …… for iteration, (images, targets, idxs, positive_map, positive_map_eval, greenlight_map) in enumerate(data_loader, start_iter): …… train_fps = cfg.SOLVER.IMS_PER_BATCH / batch_time meters.update(time=batch_time, data=data_time, fps=train_fps)  # After: images-per-second = cfg.SOLVER.IMS_PER_BATCH / batch_time is logged to meters as "fps". NOTE(review): if IMS_PER_BATCH is the global batch size across all processes (assumption — confirm), this is aggregate throughput rather than per-GPU throughput.