Sample Code for Network Freezing

To enable network freezing, perform the following steps:

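  1. In the model configuration file, list the names of the network layers to be frozen under the freeze key. Nested keys are joined with periods, and several names under the same key can be listed on separate lines. In the following example, bottleneck.block1.layer1, bottleneck.block1.layer2, and bottleneck.fc.weight are frozen.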
    freeze:
      bottleneck:
        block1:
          layer1
          layer2
        fc:
          weight
  2. Add the corresponding logic to the model code, using the following code as a reference, to freeze the network.
    import os
    import stat
    import logging
    import yaml
    import argparse
    
    """
    To enable the network freezing capability, refer to the following implementation:
    
    Instructions:
    1. Define --advanced_config by using the parameter definition tool such as argparse in the model startup script (boot_file_path).
    parser = argparse.ArgumentParser()
    parser.add_argument('--advanced_config', type=str)
    args = parser.parse_args()
    
    2. Obtain the value of --advanced_config received by boot_file_path.
    advanced_config_path = args.advanced_config
    
    3. Use MindSpore to define a model that needs to be frozen or partially frozen.
    model = mindspore.nn.Cell(...)
    
    4. Parse the network freezing configuration from advanced_config_path.
    freeze_layers = get_freeze_layers(advanced_config_path)
    
    5. Freeze the network.
    freeze_model(model, freeze_layers)
    """
    
    # Separator used to join nested [freeze] keys into dotted layer names
    # (e.g. bottleneck -> block1 -> layer1 becomes 'bottleneck.block1.layer1').
    CONN_WITH = '.'
    # Top-level key of the model configuration file that holds the freeze settings.
    FREEZE_KEY = 'freeze'

    # NOTE(review): this sets the level of the ROOT logger at import time, so it also
    # affects log output from every other module in the process.
    logging.getLogger().setLevel(logging.INFO)
    
    
    def freeze_model(model, freeze_layers):
        """
        Freeze the specified parts of the network by setting ``requires_grad = False``.

        A parameter is frozen when its name equals one of the given layer names, or when
        it starts with that name followed by '.'. Matching on whole dotted components
        means 'block1.layer1' freezes 'block1.layer1.weight' but not 'block1.layer10.weight'.

        :param model: network model; must provide ``parameters_and_names()`` yielding
            ``(name, param)`` pairs whose ``param`` has a ``requires_grad`` attribute
        :param freeze_layers: layer name(s) to freeze: a single string or an iterable of strings
        :raises ValueError: if any entry of freeze_layers is not a string
        """
        if not freeze_layers:
            logging.info('freeze_layers is empty, no layers in model will be frozen.')
            return

        if isinstance(freeze_layers, str):
            # list('a.b') would split a single name into characters.
            freeze_layers = [freeze_layers]
        elif not isinstance(freeze_layers, list):
            freeze_layers = list(freeze_layers)

        # Validate once up front rather than per (parameter, layer) pair; this also
        # rejects bad entries when the model happens to have no parameters.
        for layer in freeze_layers:
            if not isinstance(layer, str):
                raise ValueError('freeze layer is not str, freeze layer: %s' % freeze_layers)

        # Ordered map: layer name -> whether any parameter matched it.
        matched = dict.fromkeys(freeze_layers, False)

        logging.info('freeze model start.')

        for name, param in model.parameters_and_names():
            for layer in matched:
                # Parameter names are dot-separated (cf. 'bottleneck.fc.weight'), so only a
                # full match or a match ending on a '.' boundary counts.
                if name == layer or name.startswith(layer + '.'):
                    param.requires_grad = False
                    matched[layer] = True

        for layer, found in matched.items():
            if not found:
                logging.warning('layer: %s is not exist.', layer)

        logging.info("freeze model finish.")
    
    
    def get_freeze_layers(model_config_path):
        """
        Parse the names of the network layers to be frozen from the model configuration file.

        Names come from the ``freeze`` section: nested mapping keys are joined with '.',
        and whitespace-separated values become separate names. A plain string or a YAML
        list directly under ``freeze`` is also accepted.

        :param model_config_path: local path of the model configuration file; it is
            converted to an absolute path before use
        :return: list of layer-name strings for freeze_model; an empty list when the path
            is missing, is a symlink, is not a regular file, or has no usable [freeze] section
        :raises Exception: re-raises any error raised while reading/parsing the file
        """
        if model_config_path is None or not str(model_config_path):
            logging.warning('param model_config_path is None or empty.')
            return []

        # Obtain the absolute path.
        model_config_path = os.path.abspath(str(model_config_path))

        # Refuse symbolic links so the file that is read is exactly the configured one.
        if os.path.islink(model_config_path):
            logging.warning('detect link path, stop parsing freeze configs from model config file.')
            return []

        # Require a regular file: a directory would pass an existence check but fail on open.
        if not os.path.isfile(model_config_path):
            logging.error('model config file path does not exist.')
            return []

        try:
            content = read_file(model_config_path)
        except Exception as ex:
            logging.error('exception occurred when reading model config file, detail error message: %s', ex)
            raise

        # An empty YAML file loads as None and a top-level list/scalar has no keys;
        # treat both as "no freeze section" instead of crashing.
        if not isinstance(content, dict) or FREEZE_KEY not in content:
            logging.error('no [freeze] config found in model config file, no layers will be frozen.')
            return []

        freeze_info = content.get(FREEZE_KEY)
        if freeze_info is None:
            logging.error('[freeze] attribute is empty in model config file, check model config file.')
            return []

        if isinstance(freeze_info, str):
            # split() drops empty tokens; an empty name would otherwise match every parameter.
            return freeze_info.split()

        if isinstance(freeze_info, (list, tuple)):
            # YAML list form, e.g. freeze: [bottleneck.fc.weight, bottleneck.block1]
            return [name for item in freeze_info if item is not None for name in str(item).split()]

        if not isinstance(freeze_info, dict):
            logging.error('[freeze] attribute in model config file must be a string, list or mapping, '
                          'check model config file.')
            return []

        expanded_dict = expand_dict(freeze_info)
        return split_vals_with_same_key(expanded_dict)
    
    
    def read_file(model_config_path):
        """
        Read and parse a YAML configuration file.

        :param model_config_path: path of the configuration file to read
        :return: the object produced by ``yaml.safe_load`` (None for an empty file)
        :raises OSError: if the file cannot be opened
        """
        # Open read-only: reading a config must not require write permission and must
        # never create the file as a side effect.
        fd = os.open(model_config_path, os.O_RDONLY)
        with os.fdopen(fd, 'rb') as file:
            # safe_load only builds plain YAML types, never arbitrary Python objects.
            content = yaml.safe_load(file)

        return content
    
    
    def expand_dict(dict_info):
        """
        Flatten the nested mapping parsed from the [freeze] configuration.

        Nested mappings are flattened via get_prefix_dict, so their keys become dotted
        prefixes. Each leaf becomes a list of sub-names:
        - A string leaf is split on whitespace. The YAML multi-line form ``layer1`` /
          ``layer2`` is loaded as 'layer1 layer2' and therefore yields two names.
        - A list leaf contributes the names of each non-None element.

        :param dict_info: configuration mapping before flattening
        :return: dict mapping each flattened key prefix to its list of sub-name strings
        """
        common_prefix_dict = dict()

        for key_item, val_item in dict_info.items():
            if key_item is None:
                logging.error('find [none] key from [freeze] config in model config file, '
                              'config is ignored, check model config file.')
                continue

            if val_item is None:
                logging.error('attribute of key: [%s] is none, config is ignored, check model config file.', str(key_item))
                continue

            if isinstance(val_item, dict):
                nested = expand_dict(val_item)
                common_prefix_dict.update(get_prefix_dict(dict_info=nested, prefix_str=str(key_item)))
                continue

            key_str = str(key_item)
            if key_str in common_prefix_dict:
                logging.warning('find duplicate key from [freeze] part in model config file, check settings.')
                continue

            items = [item for item in val_item if item is not None] \
                if isinstance(val_item, (list, tuple)) else [val_item]
            # split() (not split(' ')) drops empty tokens from repeated/trailing spaces; an
            # empty token would become a bare 'prefix.' name that freezes the whole subtree.
            names = [name for item in items for name in str(item).split()]
            if not names:
                logging.error('attribute of key: [%s] is empty, config is ignored, check model config file.', key_str)
                continue

            common_prefix_dict[key_str] = names

        return common_prefix_dict
    
    
    def split_vals_with_same_key(expanded_dict):
        """
        Expand every (prefix, sub-names) pair into complete dotted names.

        :param expanded_dict: flattened mapping of prefix -> iterable of sub-names
        :return: list of ``prefix + CONN_WITH + sub_name`` strings, in mapping order
        """
        return [
            f'{str(prefix)}{CONN_WITH}{str(sub_name)}'
            for prefix, sub_names in expanded_dict.items()
            for sub_name in sub_names
        ]
    
    
    def get_prefix_dict(dict_info, prefix_str):
        """
        Prepend a prefix (joined with CONN_WITH) to every key of a mapping.

        :param dict_info: mapping whose keys receive the prefix
        :param prefix_str: prefix to prepend
        :return: new dict with keys ``prefix_str + CONN_WITH + str(key)`` and the original values
        """
        prefixed = {}
        for sub_key, sub_val in dict_info.items():
            prefixed[prefix_str + CONN_WITH + str(sub_key)] = sub_val
        return prefixed