在配置中增加离线(本地)分词器和文本编码器的路径变量,默认值均为空字符串;为空时保持原有的加载方式。
_C.MODEL.LANGUAGE_BACKBONE.TOKENIZER_PATH = ""  # local tokenizer path; "" falls back to loading by TOKENIZER_TYPE
_C.MODEL.LANGUAGE_BACKBONE.MODEL_PATH = ""  # local text-encoder path; "" falls back to loading by model name
修改 make_data_loader 中加载分词器的逻辑:在非 clip 分支中,若设置了 TOKENIZER_PATH,则从该本地路径加载分词器;否则仍按 TOKENIZER_TYPE 加载。
修改前:
def make_data_loader(cfg, is_train=True, is_distributed=False, num_replicas=None, rank=None, start_iter=0): if cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE == "clip": …… else: extra_args['tokenizer'] = AutoTokenizer.from_pretrained(cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE)
修改后:
def make_data_loader(cfg, is_train=True, is_distributed=False, num_replicas=None, rank=None, start_iter=0): if cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE == "clip": …… else: if cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_PATH: extra_args['tokenizer'] = AutoTokenizer.from_pretrained(cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_PATH) else: extra_args['tokenizer'] = AutoTokenizer.from_pretrained(cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE)
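修改 create_queries_and_maps 中加载分词器的逻辑:当 TOKENIZER_TYPE 为 "bert-base-uncased" 时,若设置了 TOKENIZER_PATH,则从该本地路径加载分词器;否则仍加载 "bert-base-uncased"。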
修改前:
def create_queries_and_maps(labels, label_list, additional_labels = None, cfg = None): if cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE == "bert-base-uncased": tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
修改后:
def create_queries_and_maps(labels, label_list, additional_labels = None, cfg = None): if cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE == "bert-base-uncased": if cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_PATH: tokenizer = AutoTokenizer.from_pretrained(cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_PATH) else: tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
修改 GeneralizedVLRCNN.__init__ 中加载分词器的逻辑:在非 clip 分支中,若设置了 TOKENIZER_PATH,则从该本地路径加载分词器;否则仍按 TOKENIZER_TYPE 加载。
修改前:
class GeneralizedVLRCNN(nn.Module): def __init__(self, cfg): # language encoder if cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE == "clip": …… else: self.tokenizer = AutoTokenizer.from_pretrained(cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE)
修改后:
class GeneralizedVLRCNN(nn.Module): def __init__(self, cfg): # language encoder if cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE == "clip": …… else: if cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_PATH: self.tokenizer = AutoTokenizer.from_pretrained(cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_PATH) else: self.tokenizer = AutoTokenizer.from_pretrained(cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE)
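修改 BertEncoder.__init__ 中加载文本编码器的逻辑:当 bert_name 为 "bert-base-uncased" 时,若设置了 MODEL_PATH,则 BertConfig 和 BertModel 均从该本地路径加载;否则仍按 bert_name 加载。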
修改前:
class BertEncoder(nn.Module): def __init__(self, cfg): …… if self.bert_name == "bert-base-uncased": config = BertConfig.from_pretrained(self.bert_name) config.gradient_checkpointing = self.cfg.MODEL.LANGUAGE_BACKBONE.USE_CHECKPOINT self.model = BertModel.from_pretrained(self.bert_name, add_pooling_layer=False, config=config) self.language_dim = 768
修改后:
class BertEncoder(nn.Module): def __init__(self, cfg): …… if self.bert_name == "bert-base-uncased": if cfg.MODEL.LANGUAGE_BACKBONE.MODEL_PATH: config = BertConfig.from_pretrained(cfg.MODEL.LANGUAGE_BACKBONE.MODEL_PATH) config.gradient_checkpointing = self.cfg.MODEL.LANGUAGE_BACKBONE.USE_CHECKPOINT print('config: ', config, flush=True) self.model = BertModel.from_pretrained(cfg.MODEL.LANGUAGE_BACKBONE.MODEL_PATH, add_pooling_layer=False, config=config) else: config = BertConfig.from_pretrained(self.bert_name) config.gradient_checkpointing = self.cfg.MODEL.LANGUAGE_BACKBONE.USE_CHECKPOINT self.model = BertModel.from_pretrained(self.bert_name, add_pooling_layer=False, config=config) self.language_dim = 768
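修改 ATSSLossComputation.__init__ 中加载分词器的逻辑:在非 clip 分支中,若设置了 MODEL_PATH,则从该本地路径加载分词器;否则仍按 MODEL_TYPE(self.lang)加载。注意:此处使用的是 MODEL_PATH 而不是 TOKENIZER_PATH,因此 MODEL_PATH 指向的目录中需要同时包含分词器文件。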
修改前:
class ATSSLossComputation(torch.nn.Module): def __init__(self, cfg, box_coder): …… self.lang = cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE if cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE == "clip": …… else: self.tokenizer = AutoTokenizer.from_pretrained(self.lang)
修改后:
class ATSSLossComputation(torch.nn.Module): def __init__(self, cfg, box_coder): …… self.lang = cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE if cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE == "clip": …… else: if cfg.MODEL.LANGUAGE_BACKBONE.MODEL_PATH: self.tokenizer = AutoTokenizer.from_pretrained(cfg.MODEL.LANGUAGE_BACKBONE.MODEL_PATH) else: self.tokenizer = AutoTokenizer.from_pretrained(self.lang)
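修改 VLDyHead.__init__ 中加载文本编码器配置(BertConfig)的逻辑:当 MODEL_TYPE 为 "bert-base-uncased" 时,若设置了 MODEL_PATH,则从该本地路径加载 BertConfig;否则仍按 MODEL_TYPE 加载。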
修改前:
class VLDyHead(torch.nn.Module): def __init__(self, cfg): …… if cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE == "bert-base-uncased": lang_cfg = BertConfig.from_pretrained(cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE)
修改后:
class VLDyHead(torch.nn.Module): def __init__(self, cfg): …… if cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE == "bert-base-uncased": if cfg.MODEL.LANGUAGE_BACKBONE.MODEL_PATH: lang_cfg = BertConfig.from_pretrained(cfg.MODEL.LANGUAGE_BACKBONE.MODEL_PATH) else: lang_cfg = BertConfig.from_pretrained(cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE)