Post-Pruning Distillation for Large Models

Overview

Post-pruning distillation is one of the key techniques for compressing large models. It mainly targets Transformer-style models whose structure is initialized from a JSON file, i.e., the model structure is defined in a JSON file, and different structures are instantiated by modifying the configuration items in that file. The overall flow is prune first, then distill: a pruning algorithm first cuts away part of the redundant width and depth of the large model's structure to generate a small model for the target scenario, and knowledge distillation then lets the generated small model learn the experience and knowledge of the original large model, improving the small model's accuracy.
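
For instance, such a model typically exposes depth and width as fields in its JSON configuration, so a smaller structure can be produced simply by editing those fields. The following minimal Python sketch shows the idea; the key names num_hidden_layers, hidden_size, and num_attention_heads are illustrative assumptions, not the actual keys used by the configuration files in this example.

    import json

    # Illustrative teacher configuration; the real key names may differ.
    teacher_cfg = {"num_hidden_layers": 12, "hidden_size": 768, "num_attention_heads": 12}

    # Derive a student configuration by shrinking the depth from 12 to 7 blocks.
    student_cfg = dict(teacher_cfg)
    student_cfg["num_hidden_layers"] = 7

    with open("ft_cap_base_student.json", "w") as f:
        json.dump(student_cfg, f, indent=2)
    # The framework then instantiates the smaller model from this JSON file.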

For the detailed implementation, see Post-Pruning Distillation Tuning Process; for the concrete operation steps, see Procedure.

Post-Pruning Distillation Tuning Process

The post-pruning distillation tuning of a large model runs as two pipeline stages: a prebuild stage, which prunes the teacher into a student network, runs a forward pass on random data to record the shapes of the matched intermediate outputs, and assembles the distillation model; and a finetune stage, which trains the student under the combined distillation losses and then evaluates it.
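
As a rough illustration of how these losses are combined, the sketch below mirrors the configuration in the procedure: a user-provided hard-label loss weighted by 0.5 (hard_label_loss_weight) plus a HiddenMse loss with weight 100.0 that matches the output of teacher block 11 to student block 6. The weighted-sum form and the function names are assumptions for illustration, not the toolkit's API.

    import numpy as np

    def hidden_mse(t_hidden, s_hidden):
        # HiddenMse: mean squared error between matched teacher and student
        # block outputs (teacher block 11 vs. student block 6 in the config).
        return float(np.mean((t_hidden - s_hidden) ** 2))

    rng = np.random.default_rng(0)
    t_out = rng.normal(size=(50, 128, 768))  # teacher hidden states (batch 50, hidden size 768)
    s_out = rng.normal(size=(50, 128, 768))  # student hidden states; the sequence length 128 is arbitrary
    hard_loss = 2.31                         # user-provided task loss, e.g. caption cross-entropy

    # Assumed combination: weighted hard-label loss plus weighted inter-layer losses.
    total_loss = 0.5 * hard_loss + 100.0 * hidden_mse(t_out, s_out)
    print(f"total distillation loss: {total_loss:.3f}")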

Procedure

Taking the text-generation (caption) downstream task as an example, the procedure is as follows.

  1. Go to the {CANN installation path}/ascend-toolkit/latest/tools/ascend_automl/examples/mindspore/prune/mm/caption directory.
  2. Following the opt_caption_ms_prune.md file, download the 紫东.太初 (Zidong.Taichu) model source code and dataset, and make the adaptation changes to the model scripts that it describes.
  3. In the opt_caption_ms_prune.yml file, configure the fields marked with inline comments below according to your actual environment (the prune_blocks_params block mapping is illustrated in the sketch after the file).

    general:
      backend: mindspore
      parallel_search: False
      ms_execute_mode: 0
      dataset_sink_mode: False
      worker:
        timeout: 7200000
    
    register:
      pkg_path: [ "/examples/models/opt-scripts" ]        # path to the 紫东.太初 (Zidong.Taichu) source code
      modules:
        - module: "src.scripts.train_caption" # 模块导入
          ori_train_func: ["main"]
          script_network: "create_network_*"
    
        - module: "src.scripts.test_caption" # 模块导入
          ori_eval_func: ["main"]
          script_network: "create_network_*"
    
    pipeline: [ prebuild, finetune ]
    
    prebuild:
      pipe_step:
        type: TrainPipeStep
      dataset:
        type: RandomDataset
        image_size: 224
        batch_size: 50
        channel_size: 3
        img_len: 1000
    
      model:
        model_desc:
          teacher:
            pretrained_model_file: /examples/OPT_caption-10_11300.ckpt     # pretrained weight file
            strict: True
            model_desc:
              type: ScriptModelGen
              train:
                network:
                  type: create_network_caption_finetune
                  file_config: /examples/models/opt-scripts/config/ft_cap_base.json    # teacher model config file for the text-generation (caption) task
                  config: &train_config
                    epochs: 10
                    start_learning_rate: 5.0e-5
                    end_learning_rate: 1.0e-7
                    use_txt_out: False
                    use_video: False
                    use_parallel: False
                    audio_dim: 512
                    use_data_fix: True
                    use_mask_fix: True
                    output_dir: <workspace_path>
                    init_loss_scale: 65536
                    loss_scale_factor: 2
                    scale_window: 1000
                    load_ckpt: True
                    save_checkpoint_steps: 5000
                    ckpt_file: ""
                    mae_ckpt_file: ""
                    sink_size: 2
                    full_batch: False
                    use_moe: False
                    hidden_size: 768
          student:
            strict: True
            pretrained_model_file: /examples/models/opt-scripts/OPT_caption-10_11300.ckpt   # pretrained weight file
            prune_state_dict_steps: ['prune_blocks', 'prune_bert_intra_block']
            prune_blocks_params: [{'pattern': 'opt_model\.uniter\.encoder\.encoder\.blocks\.(\d+)\.',
                                   'layer_id_map': {0: 0, 1: 2, 2: 4,  3: 6, 4: 8, 5: 10, 6: 11}}]
            model_desc:
              type: ScriptModelGen
              train:
                network:
                  type: &train_network create_network_caption_finetune
                  file_config: &train_config_file /examples/models/opt-scripts/config/ft_cap_base_student.json
                  config: *train_config
              evaluate:
                network:
                  type: &eval_network create_network_caption_finetune_eval
                  file_config: &eval_config_file /examples/models/opt-scripts/config/ft_cap_base_student.json   #学生模型文本生成任务参数配置文件
                  config: &eval_config
                    ckpt_file: <weights_file>
                    output_dir: <workspace_path>
                    use_parallel: False
                    audio_dim: 512
                    start_learning_rate: 5.0e-5
                    end_learning_rate: 1.0e-7
                    use_txt_out: False
                    use_video: False
                    use_data_fix: True
                    use_mask_fix: True
                    init_loss_scale: 65536
                    loss_scale_factor: 2
                    scale_window: 1000
                    mae_ckpt_file: ""
                    sink_size: 2
                    full_batch: False
                    use_moe: False
                    hidden_size: 768
      trainer:
        type: GetOutputShapeTrainer
        callbacks: [GetOutputShapeCallback, DistillModelBuilder]
        ori_trainer:
          type: main
          model_register_func_type: *train_network
          file_config: *train_config_file
          config: *train_config
        with_train: True
        hard_label_loss_weight: 0.5  # when using a user-provided loss, its weight must be set
        distill_loss_manager:
          output_replace_idx: 0
          output_matches: ~
          inter_matches: [
            {
              't_module': 'uniter.encoder.encoder.blocks.11.output',
              's_module': 'uniter.encoder.encoder.blocks.6.output',
              't_output_idx': 0,
              's_output_idx': 0,
              'loss_types': [ 'HiddenMse' ],
              'loss_weights': [ 100.0 ]
            },
          ] # DistillationLosses
    
    
    finetune:
      pipe_step:
        type: TrainPipeStep
      dataset:
        ref: prebuild.dataset
    
      model:
        ref: prebuild.model
      trainer:
        ref: prebuild.trainer
        type: OriDistillTrainer
        callbacks: [DistillModelBuilder]
        output_shapes_file: "{local_base_path}/output/prebuild/model_0_output_shapes.json"
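        # intermediate-output shapes recorded during the prebuild stage; see the sketch at the end of this section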
    
      evaluator:
        type: Evaluator
        host_evaluator:
          type: OriHostEvaluator
          ori_eval:
            type: main
            model_register_func_type: *eval_network
            file_config: *eval_config_file
            config: *eval_config
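
    In this configuration, the prune_blocks step keeps 7 of the teacher's 12 Transformer blocks: each layer_id_map entry is a student-block-index to teacher-block-index pair, so student block 1 is initialized from teacher block 2, and so on, while the parameters of the dropped teacher blocks are discarded. The following Python sketch illustrates the resulting checkpoint-key remapping; the helper is a hypothetical illustration, not the toolkit's implementation.

        import re

        pattern = r'opt_model\.uniter\.encoder\.encoder\.blocks\.(\d+)\.'
        layer_id_map = {0: 0, 1: 2, 2: 4, 3: 6, 4: 8, 5: 10, 6: 11}
        teacher_to_student = {t: s for s, t in layer_id_map.items()}

        def remap_key(teacher_key):
            # Keep a parameter only if its block survives pruning, and
            # renumber it to the student's block index.
            m = re.match(pattern, teacher_key)
            if not m:
                return teacher_key                # non-block parameter: copy as-is
            t_idx = int(m.group(1))
            if t_idx not in teacher_to_student:
                return None                       # pruned block: drop the parameter
            s_idx = teacher_to_student[t_idx]
            return teacher_key.replace(f'.blocks.{t_idx}.', f'.blocks.{s_idx}.', 1)

        print(remap_key('opt_model.uniter.encoder.encoder.blocks.10.output.dense.weight'))
        # -> opt_model.uniter.encoder.encoder.blocks.5.output.dense.weight
        print(remap_key('opt_model.uniter.encoder.encoder.blocks.1.output.dense.weight'))
        # -> None (block 1 is pruned away)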

  4. Launch the text-generation downstream task.

    mxOps prune -m taichu -c opt_caption_ms_prune.yml -d NPU

    Besides the mxOps tool, the task can also be launched with Vega; refer to the following command:

    vega opt_caption_ms_prune.yml -d NPU
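
For reference, the output_shapes_file consumed by the finetune stage is written during prebuild, where GetOutputShapeCallback records the shapes of the intermediate outputs observed on the random-data forward pass. Conceptually the result resembles the following sketch; the toy model and the exact file layout are assumptions for illustration.

    import json
    import numpy as np

    class Block:
        # Toy stand-in for a Transformer block: maps inputs to a fixed hidden size.
        def __init__(self, hidden_size):
            self.hidden_size = hidden_size

        def __call__(self, x):
            return np.zeros((x.shape[0], x.shape[1], self.hidden_size))

    blocks = [Block(768) for _ in range(7)]  # student depth, per the config above
    x = np.zeros((50, 128, 768))             # batch 50 from RandomDataset; sequence length arbitrary
    shapes = {}
    for i, blk in enumerate(blocks):
        x = blk(x)
        shapes[f"uniter.encoder.encoder.blocks.{i}.output"] = list(x.shape)

    with open("model_0_output_shapes.json", "w") as f:
        json.dump(shapes, f, indent=2)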