网络冻结样例代码
如果用户需要使用网络冻结能力,请参考以下步骤适配:
- 在模型配置文件中的freeze关键词下,配置需要冻结的网络层名称。下方示例表示冻结bottleneck.block1.layer1、bottleneck.block1.layer2、bottleneck.fc.weight三层。
freeze:
  bottleneck:
    block1: layer1 layer2
    fc: weight
- 参考下方代码,将对应逻辑适配到模型代码中,实现网络冻结。
"""
Sample code for freezing network layers.

To use the network-freezing capability, adapt your model code as follows:

1. In the model boot script (boot_file_path), define the '--advanced_config'
   argument with argparse or a similar tool:
       parser = argparse.ArgumentParser()
       parser.add_argument('--advanced_config', type=str)
       args = parser.parse_args()
2. Read the '--advanced_config' value received by the boot script:
       advanced_config_path = args.advanced_config
3. Define the model to be frozen / partially frozen with mindspore:
       model = mindspore.nn.Cell(...)
4. Parse the freeze configuration from advanced_config_path:
       freeze_layers = get_freeze_layers(advanced_config_path)
5. Freeze the network:
       freeze_model(model, freeze_layers)
"""
import argparse
import logging
import os
import stat

try:
    import yaml
except ImportError:
    # PyYAML is only needed by read_file(); keep the rest of the module usable.
    yaml = None

CONN_WITH = '.'
FREEZE_KEY = 'freeze'

logging.getLogger().setLevel(logging.INFO)


def freeze_model(model, freeze_layers):
    """
    Freeze the specified parts of a network.

    Every parameter whose name starts with one of the given prefixes gets
    ``requires_grad = False``. A warning is logged for each prefix that does
    not match any parameter.

    :param model: network model; must provide ``parameters_and_names()``
        yielding ``(name, parameter)`` pairs
    :param freeze_layers: parts to freeze, a list of strings (a single string
        is treated as a one-element list)
    :raises ValueError: if an entry of freeze_layers is not a string
    """
    if not freeze_layers:
        logging.info('freeze_layers is empty, no layers in model will be frozen.')
        return

    if isinstance(freeze_layers, str):
        # list('abc') would split the name into single characters.
        freeze_layers = [freeze_layers]
    elif not isinstance(freeze_layers, list):
        freeze_layers = list(freeze_layers)

    # Validate up front so that nothing is frozen when the input is invalid.
    for layer in freeze_layers:
        if not isinstance(layer, str):
            raise ValueError('freeze layer is not str, freeze layer: %s' % freeze_layers)

    matched = set()
    logging.info('freeze model start.')
    for name, param in model.parameters_and_names():
        for layer in freeze_layers:
            # NOTE: plain prefix match, so 'a.layer1' also matches 'a.layer10'.
            if name.startswith(layer):
                param.requires_grad = False
                matched.add(layer)

    for layer in freeze_layers:
        if layer not in matched:
            logging.warning('layer: %s is not exist.', layer)
    logging.info("freeze model finish.")


def get_freeze_layers(model_config_path):
    """
    Parse the layers to freeze from the model config file.

    :param model_config_path: local path of the model config (YAML) file
    :return: list of layer-name prefixes to freeze; empty when the path is
        missing, is a symbolic link, does not exist, or holds no usable
        [freeze] section
    :raises Exception: any error raised while reading the file is logged and
        re-raised
    """
    if model_config_path is None or not str(model_config_path):
        logging.warning('param model_config_path is None or empty.')
        return []

    # Work on the absolute path.
    model_config_path = os.path.abspath(str(model_config_path))

    # Symbolic links are rejected.
    if os.path.islink(model_config_path):
        logging.warning('detect link path, stop parsing freeze configs from model config file.')
        return []

    # The path must exist.
    if not os.path.exists(model_config_path):
        logging.error('model config file path does not exist.')
        return []

    try:
        content = read_file(model_config_path)
    except Exception as ex:
        logging.error('exception occurred when reading model config file, detail error message: %s', ex)
        raise

    # An empty YAML document loads as None; a scalar/list document has no keys.
    if not isinstance(content, dict) or FREEZE_KEY not in content:
        logging.error('no [freeze] config found in model config file, no layers will be frozen.')
        return []

    freeze_info = content.get(FREEZE_KEY)
    if freeze_info is None:
        logging.error('[freeze] attribute is empty in model config file, check model config file.')
        return []

    if isinstance(freeze_info, str):
        return [freeze_info]

    if not isinstance(freeze_info, dict):
        logging.error('[freeze] attribute must be a string or a mapping in model config file, '
                      'check model config file.')
        return []

    return split_vals_with_same_key(expand_dict(freeze_info))


def read_file(model_config_path):
    """
    Read the config file and parse it as YAML.

    :param model_config_path: path of the config file
    :return: the object produced by ``yaml.safe_load``
    :raises ImportError: if PyYAML is not installed
    :raises OSError: if the file cannot be opened for reading
    """
    if yaml is None:
        raise ImportError('PyYAML is required to read the model config file.')
    # Read-only: parsing must neither require write permission nor create the file.
    with os.fdopen(os.open(model_config_path, os.O_RDONLY), 'rb') as file:
        return yaml.safe_load(file)


def _split_names(value):
    """Turn a leaf config value into a list of names (whitespace separated)."""
    items = value if isinstance(value, (list, tuple)) else [value]
    # str.split() without arguments drops empty pieces caused by repeated spaces.
    return [name for item in items if item is not None for name in str(item).split()]


def expand_dict(dict_info):
    """
    Flatten the nested dict parsed from the freeze configuration.

    Nested keys are joined with CONN_WITH; each leaf value is split on
    whitespace into a list of names. Entries with a None key or a None value
    are logged and ignored.

    :param dict_info: config dict before flattening
    :return: flattened dict mapping a dotted prefix to a list of names
    """
    common_prefix_dict = {}
    for key_item, val_item in dict_info.items():
        if key_item is None:
            logging.error('find [none] key from [freeze] config in model config file, '
                          'config is ignored, check model config file.')
            continue
        if val_item is None:
            logging.error('attribute of key: [%s] is none, config is ignored, check model config file.',
                          str(key_item))
            continue

        key_str = str(key_item)
        if isinstance(val_item, dict):
            nested = expand_dict(val_item)
            common_prefix_dict.update(get_prefix_dict(dict_info=nested, prefix_str=key_str))
        elif key_str in common_prefix_dict:
            logging.warning('find duplicate key from [freeze] part in model config file, check settings.')
        else:
            common_prefix_dict[key_str] = _split_names(val_item)
    return common_prefix_dict


def split_vals_with_same_key(expanded_dict):
    """
    Build one full name for every (prefix, name) pair.

    :param expanded_dict: flattened dict mapping a prefix to a list of names
    :return: list of full names, each formed as prefix + CONN_WITH + name
    """
    return [
        f'{str(key_item)}{CONN_WITH}{str(val)}'
        for key_item, val_item in expanded_dict.items()
        for val in val_item
    ]


def get_prefix_dict(dict_info, prefix_str):
    """
    Prepend a prefix to every key of a dict (used for nested levels).

    :param dict_info: config dict
    :param prefix_str: prefix to prepend
    :return: new dict whose keys are prefix_str + CONN_WITH + original key
    """
    return {prefix_str + CONN_WITH + str(k): v for k, v in dict_info.items()}