昇腾社区首页
中文
注册

网络冻结样例代码

如果用户需要使用网络冻结能力,请参考以下步骤适配:

  1. 在模型配置文件中的freeze关键词下,配置需要冻结的网络层名称。嵌套的键会用“.”依次拼接成完整名称,同一键下的多个子名称以空格或换行分隔。下方示例表示冻结bottleneck.block1.layer1、bottleneck.block1.layer2、bottleneck.fc.weight三项。
    freeze:
      bottleneck:
        block1:
          layer1
          layer2
        fc:
          weight
  2. 参考下方代码,将对应逻辑适配到模型代码中,实现网络冻结。
    import os
    import stat
    import logging
    import yaml
    import argparse
    
    """
    需要使用网络冻结能力时, 请参考如下实现
    
    使用方式:
    1、在模型启动脚本(boot_file_path)中, 通过argparse等参数定义工具定义入参'--advanced_config'
    parser = argparse.ArgumentParser()
    parser.add_argument('--advanced_config', type=str)
    args = parser.parse_args()
    
    2、获取模型启动脚本(boot_file_path)接收到的'--advanced_config'值
    advanced_config_path = args.advanced_config
    
    3、使用mindspore定义需要冻结/部分冻结的模型
    model = mindspore.nn.Cell(...)
    
    4、从advanced_config_path中解析网络冻结配置
    freeze_layers = get_freeze_layers(advanced_config_path)
    
    5、冻结网络
    freeze_model(model, freeze_layers)
    """
    
    # Separator used to join nested [freeze] keys into dotted full names.
    CONN_WITH: str = '.'
    # Top-level key of the model config file that holds the freeze settings.
    FREEZE_KEY: str = 'freeze'

    # NOTE(review): this sets the process-wide root logger level as a side
    # effect of importing this module, so the INFO messages below are shown.
    logging.root.setLevel(logging.INFO)
    
    
    def freeze_model(model, freeze_layers):
        """
        Freeze the specified parts of a network.

        A parameter gets ``requires_grad = False`` in either of two cases:
        its name equals a configured entry, or its name starts with the entry
        followed by ``CONN_WITH`` (the entry names an enclosing sub-module).
        Matching stops at name-component boundaries, so ``net.layer1`` freezes
        ``net.layer1.weight`` but not ``net.layer10.weight``. Entries that match
        no parameter are logged as warnings.

        :param model: network model exposing ``parameters_and_names()``, which
            yields ``(name, param)`` pairs (e.g. a ``mindspore.nn.Cell``)
        :param freeze_layers: names to freeze, given as a single string or an
            iterable of strings; an empty value freezes nothing
        :raises ValueError: if any entry is not a string; raised before any
            parameter is modified
        """
        if not freeze_layers:
            logging.info('freeze_layers is empty, no layers in model will be frozen.')
            return

        if isinstance(freeze_layers, str):
            # list() on a string would split one name into single characters.
            freeze_layers = [freeze_layers]
        elif not isinstance(freeze_layers, list):
            freeze_layers = list(freeze_layers)

        # Validate every entry first so a bad one cannot leave the model half-frozen.
        if not all(isinstance(layer, str) for layer in freeze_layers):
            raise ValueError('freeze layer is not str, freeze layer: %s' % freeze_layers)

        # Sub-module prefixes, computed once instead of per parameter.
        prefixes = [layer + CONN_WITH for layer in freeze_layers]
        exists = [False] * len(freeze_layers)

        logging.info('freeze model start.')

        for name, param in model.parameters_and_names():
            for index, (layer, prefix) in enumerate(zip(freeze_layers, prefixes)):
                if name == layer or name.startswith(prefix):
                    param.requires_grad = False
                    exists[index] = True

        for layer, found in zip(freeze_layers, exists):
            if not found:
                logging.warning('layer: %s is not exist.', layer)

        logging.info("freeze model finish.")
    
    
    def get_freeze_layers(model_config_path):
        """
        Parse the names of the layers to freeze from a model config file.

        :param model_config_path: local path of the model config file; it is
            converted to an absolute path before use
        :return: list of dotted layer names to pass to ``freeze_model``. It is
            an empty list when the path is None/empty, is a symbolic link, or
            does not exist. It is also empty when the file content is not a
            mapping, or when the [freeze] section is missing, empty, or of an
            unsupported type.
        :raises Exception: any error raised while reading/parsing the file is
            logged and re-raised unchanged
        """
        if model_config_path is None or not str(model_config_path):
            logging.warning('param model_config_path is None or empty.')
            return []

        # Resolve to an absolute path.
        model_config_path = os.path.abspath(str(model_config_path))

        # Refuse symbolic links. Only the final path component is checked here.
        if os.path.islink(model_config_path):
            logging.warning('detect link path, stop parsing freeze configs from model config file.')
            return []

        # The path must exist.
        if not os.path.exists(model_config_path):
            logging.error('model config file path does not exist.')
            return []

        try:
            content = read_file(model_config_path)
        except Exception as ex:
            logging.error('exception occurred when reading model config file, detail error message: %s', ex)
            raise

        # yaml.safe_load yields None for an empty file and may yield a list or a
        # scalar, none of which supports key lookup.
        if not isinstance(content, dict):
            logging.error('model config file content is not a mapping, no layers will be frozen.')
            return []

        if FREEZE_KEY not in content:
            logging.error('no [freeze] config found in model config file, no layers will be frozen.')
            return []

        freeze_info = content.get(FREEZE_KEY)
        if freeze_info is None:
            logging.error('[freeze] attribute is empty in model config file, check model config file.')
            return []

        if isinstance(freeze_info, str):
            return [freeze_info]

        # Only a nested mapping can be flattened into dotted names.
        if not isinstance(freeze_info, dict):
            logging.error('[freeze] attribute in model config file must be a string or a mapping, '
                          'no layers will be frozen.')
            return []

        expanded_dict = expand_dict(freeze_info)
        return split_vals_with_same_key(expanded_dict)
    
    
    def read_file(model_config_path):
        """
        Read and parse a YAML config file.

        :param model_config_path: path of the config file to read
        :return: the object produced by ``yaml.safe_load`` (None for an empty file)
        :raises OSError: if the file cannot be opened for reading
        """
        # Open read-only. A config file only needs to be read. Requesting write
        # access would fail on read-only files, and O_CREAT would silently create
        # an empty file for a missing path.
        with open(model_config_path, 'rb') as file:
            content = yaml.safe_load(file)

        return content
    
    
    def expand_dict(dict_info):
        """
        Flatten the nested freeze configuration into a single-level dict.

        A nested mapping is flattened recursively, and its keys are prefixed
        with the parent key via ``get_prefix_dict``. A scalar value is converted
        with ``str()`` and split on any whitespace into child names. A list value
        has each non-None item treated the same way. For example,
        ``{'fc': 'weight bias'}`` becomes ``{'fc': ['weight', 'bias']}``.

        Entries with a None key, a None value, or a value yielding no names are
        logged and ignored. A flattened key that is already present is logged
        and the later value is ignored.

        :param dict_info: configuration dict before flattening
        :return: flattened dict mapping a (possibly dotted) prefix to a list of
            child-name strings
        """
        common_prefix_dict = {}

        for key_item, val_item in dict_info.items():
            if key_item is None:
                logging.error('find [none] key from [freeze] config in model config file, '
                              'config is ignored, check model config file.')
                continue

            if val_item is None:
                logging.error('attribute of key: [%s] is none, config is ignored, check model config file.', str(key_item))
                continue

            if isinstance(val_item, dict):
                val_item = expand_dict(val_item)
                common_prefix_dict.update(get_prefix_dict(dict_info=val_item, prefix_str=str(key_item)))
                continue

            if str(key_item) in common_prefix_dict:
                logging.warning('find duplicate key from [freeze] part in model config file, check settings.')
                continue

            # split() without an argument drops empty strings. An empty name would
            # otherwise become "prefix." and match every parameter under the prefix.
            if isinstance(val_item, list):
                names = [name for item in val_item if item is not None for name in str(item).split()]
            else:
                names = str(val_item).split()

            if not names:
                logging.error('attribute of key: [%s] is empty, config is ignored, check model config file.', str(key_item))
                continue

            common_prefix_dict[str(key_item)] = names

        return common_prefix_dict
    
    
    def split_vals_with_same_key(expanded_dict):
        """
        Expand each prefix into one full name per child name.

        Every ``prefix -> [child, ...]`` entry contributes
        ``prefix + CONN_WITH + child`` for each child. Results follow dict order,
        then list order.

        :param expanded_dict: flattened dict mapping a prefix to child names
        :return: list of complete dotted names
        """
        return [
            f'{str(prefix)}{CONN_WITH}{str(child)}'
            for prefix, children in expanded_dict.items()
            for child in children
        ]
    
    
    def get_prefix_dict(dict_info, prefix_str):
        """
        Return a copy of ``dict_info`` whose keys carry ``prefix_str`` in front.

        Each key ``k`` becomes ``prefix_str + CONN_WITH + str(k)``; values are
        kept as-is. This is used when flattening nested configuration.

        :param dict_info: configuration dict whose keys receive the prefix
        :param prefix_str: prefix to prepend
        :return: new dict with prefixed keys
        """
        prefixed = {}
        for child_key, child_val in dict_info.items():
            prefixed[prefix_str + CONN_WITH + str(child_key)] = child_val
        return prefixed