其他修改

增加设置迭代步数和提前停止训练逻辑

  1. 修改“tools/train_net.py”脚本。

    新增命令行参数--early_stop_iteration（整数，默认值为-1，表示不提前停止），并将其通过early_stop_iteration参数传入训练主函数train()。

    修改前:

    def train(cfg, local_rank, distributed, use_tensorboard=False,):
    ……
        do_train(
            cfg,
            model,
            data_loader,
            optimizer,
            scheduler,
            checkpointer,
            device,
            checkpoint_period,
            arguments,
            data_loaders_val,
            meters
        )
    
    return model
    
    def main():
    parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")
    ……
    parser.add_argument("--override_output_dir", default=None)
    
    args = parser.parse_args()
    ……
        model = train(cfg=cfg,
                      local_rank=args.local_rank,
                      distributed=args.distributed,
                      use_tensorboard=args.use_tensorboard)

    修改后:

    def train(cfg, local_rank, distributed, use_tensorboard=False, early_stop_iteration=-1):
    ……
        do_train(
            cfg,
            model,
            data_loader,
            optimizer,
            scheduler,
            checkpointer,
            device,
            checkpoint_period,
            arguments,
            data_loaders_val,
            meters,
            early_stop_iteration=early_stop_iteration,
        )
    
    return model
    
    def main():
    parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")
    ……
    parser.add_argument("--override_output_dir", default=None)
    parser.add_argument("--early_stop_iteration", type=int, default=-1)
    
    args = parser.parse_args()
    ……
        model = train(cfg=cfg,
                      local_rank=args.local_rank,
                      distributed=args.distributed,
                      use_tensorboard=args.use_tensorboard,
                      early_stop_iteration=args.early_stop_iteration)

  2. 修改“maskrcnn_benchmark/engine/trainer.py”脚本。

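    为训练函数do_train()新增early_stop_iteration参数（默认值为-1）。当该参数大于0时，训练在完成第early_stop_iteration次迭代后提前停止；若提供了val_data_loader，则在第early_stop_iteration次迭代时额外执行一次验证。注意：break发生在arguments["iteration"]被赋值之后，因此停止时arguments["iteration"]的值为early_stop_iteration + 1。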

    修改前:

    def do_train(
            cfg,
            model,
            data_loader,
            optimizer,
            scheduler,
            checkpointer,
            device,
            checkpoint_period,
            arguments,
            val_data_loader=None,
            meters=None,
            zero_shot=False
    ):
    ……
    for iteration, (images, targets, idxs, positive_map, positive_map_eval, greenlight_map) in enumerate(data_loader, start_iter):
        ……
            arguments["iteration"] = iteration
    
    images = images.to(device)
    ……
            if val_data_loader and (iteration % checkpoint_period == 0 or iteration == max_iter):
                if is_main_process():
                    print("Evaluating")
                ……

    修改后:

    def do_train(
            cfg,
            model,
            data_loader,
            optimizer,
            scheduler,
            checkpointer,
            device,
            checkpoint_period,
            arguments,
            val_data_loader=None,
            meters=None,
            zero_shot=False,
            early_stop_iteration=-1,
    ):
    ……
    for iteration, (images, targets, idxs, positive_map, positive_map_eval, greenlight_map) in enumerate(data_loader, start_iter):
        ……
            arguments["iteration"] = iteration
            if early_stop_iteration > 0:
                if iteration == early_stop_iteration + 1:
                    break
    
    images = images.to(device)
    ……
    
            if val_data_loader and (iteration % checkpoint_period == 0 or iteration == max_iter or
                                    iteration == early_stop_iteration):
                if is_main_process():
                    print("Evaluating")
                ……

增加训练时计算实时fps的逻辑

修改“maskrcnn_benchmark/engine/trainer.py”脚本。

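增加训练时计算实时fps的逻辑：每次迭代结束时按 train_fps = cfg.SOLVER.IMS_PER_BATCH / batch_time 计算训练吞吐量，并通过meters.update()与time、data一起记录。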

修改前:

def do_train(
        cfg,
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
        val_data_loader=None,
        meters=None,
        zero_shot=False,
        early_stop_iteration=-1,
):
……
for iteration, (images, targets, idxs, positive_map, positive_map_eval, greenlight_map) in enumerate(data_loader, start_iter):
    ……
    meters.update(time=batch_time, data=data_time)

修改后:

def do_train(
        cfg,
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
        val_data_loader=None,
        meters=None,
        zero_shot=False,
        early_stop_iteration=-1,
):
……
for iteration, (images, targets, idxs, positive_map, positive_map_eval, greenlight_map) in enumerate(data_loader, start_iter):
    ……
    train_fps = cfg.SOLVER.IMS_PER_BATCH / batch_time
    meters.update(time=batch_time, data=data_time, fps=train_fps)