Sparsity Process

This section describes the sparse layers supported by manual sparsity, the API call sequence, and a calling example.

Currently, AMCT supports only retraining-based filter-level sparsity. For the sparsity example, see "Filter-level sparsity" in Sample List. The layers that support filter-level sparsity and their restrictions are listed in Table 1.

Table 1 Layers that support filter-level sparsity and their restrictions

Technique: Filter-level sparsity

  Supported layer type: torch.nn.Linear
    Restrictions:
      • Layers sharing the weight and bias parameters do not support sparsity.

  Supported layer type: torch.nn.Conv2d
    Restrictions:
      • Layers sharing the weight and bias parameters do not support sparsity.
      • Depthwise convolutions (groups = in_channels) support only passive sparsity.
      • The shape of the input data must be (N, Cin, Hin, Win).
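For illustration only (AMCT inserts the mask operator for you through its APIs), the following sketch shows what filter-level sparsity means for a torch.nn.Conv2d layer: entire filters (output channels) with the smallest L1 norms are zeroed, rather than individual weights. The helper name and pruning criterion here are illustrative, not the AMCT implementation:

```python
import torch
import torch.nn as nn

def mask_smallest_filters(conv: nn.Conv2d, prune_ratio: float) -> torch.Tensor:
    """Zero out the filters (output channels) with the smallest L1 norms.

    Illustrates the filter-level sparsity idea only; in AMCT this is done
    by the mask operator inserted via create_prune_retrain_model.
    """
    with torch.no_grad():
        # One L1 norm per filter: weight shape is (out_channels, in_channels, kH, kW).
        norms = conv.weight.abs().sum(dim=(1, 2, 3))
        num_pruned = int(conv.out_channels * prune_ratio)
        # Indices of the weakest filters.
        pruned = torch.argsort(norms)[:num_pruned]
        mask = torch.ones(conv.out_channels)
        mask[pruned] = 0.0
        # Broadcast the per-filter mask over the whole weight tensor.
        conv.weight.mul_(mask.view(-1, 1, 1, 1))
        if conv.bias is not None:
            conv.bias.mul_(mask)
    return mask

conv = nn.Conv2d(3, 8, kernel_size=3)
mask = mask_smallest_filters(conv, prune_ratio=0.5)
```

Because whole output channels are removed, filter-level sparsity changes the effective channel count seen by the next layer, which is why downstream depthwise layers can only be sparsified passively.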

API Call Sequence

Figure 1 shows the API call sequence for filter-level sparsity.

Figure 1 API call sequence for filter-level sparsity
Operations in blue are implemented by the user; operations in gray are implemented by calling AMCT APIs.
  1. Build a source PyTorch model, call the create_prune_retrain_model API to modify the model, and insert the filter-level sparsity mask operator into the graph structure to prune model parameters.
  2. Train the modified model until the accuracy meets your requirement. If the training process is interrupted, call the restore_prune_retrain_model API to prune the source model again based on the sparsity record file and resume retraining until the accuracy meets your requirement.
  3. Generate the .pth file that meets the accuracy requirement from the final retrained filter-level sparsity model. Alternatively, call the save_prune_retrain_model API to generate the final ONNX fake-sparse model and the deployable model.

Calling Example

Take the following steps to get started, updating the sample code to suit your own situation and tweaking the arguments passed to AMCT API calls as required. Sparsity relies on the user training result, so make sure a PyTorch training script that yields satisfactory training accuracy is available before you begin.
  1. Import the AMCT package and set the log level (see Post-installation Actions for details).
    import amct_pytorch as amct
    
  2. (Optional) Run inference on the source model in the PyTorch environment based on the test dataset to validate the inference script and environment setup. (Update the sample code based on your situation.)

    This step is recommended as it guarantees a properly functioning source model for inference with acceptable accuracy. You can use a subset from the test dataset to improve the efficiency.

    ori_model.load()
    # Test the model.
    user_test_model(ori_model, test_data, test_iterations)
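user_test_model is a user-supplied helper, not an AMCT API. A minimal sketch, assuming a classification model and an iterable of (inputs, labels) batches (the name and signature are illustrative):

```python
import torch

def user_test_model(model, test_data, test_iterations):
    """Hypothetical evaluation helper: top-1 accuracy over at most
    test_iterations batches of (inputs, labels) pairs."""
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for i, (inputs, labels) in enumerate(test_data):
            if i >= test_iterations:
                break
            preds = model(inputs).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.numel()
    return correct / max(total, 1)
```

Running on a subset of the test set, as suggested above, just means passing a smaller test_iterations.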
    
  3. Call AMCT to perform training with the sparsity operator.
    1. Modify the original model and insert the filter-level sparsity mask operator into the graph structure.
      Before performing this step, restore the already-trained parameters, for example, ori_model.load() in 2.
      simple_cfg = './retrain.cfg'
      record_file = './tmp/record.txt'
      prune_retrain_model = amct.create_prune_retrain_model(model=ori_model,
                                                            input_data=ori_model_input_data,
                                                            config_defination=simple_cfg,
                                                            record_file=record_file)
      
    2. Implement gradient descent optimization on the modified graph and train the graph on the training dataset. (Update the sample code based on your situation.)
      1. Implement gradient descent optimization on the modified graph.
        Perform this step after the model is trimmed.
        optimizer = user_create_optimizer(prune_retrain_model)
        
      2. Restore the model from existing checkpoints and train the model.

        Note: Restore model parameters from the trained checkpoints before training.

        quant_pth = './ckpt/user_model'  # checkpoint path used by the user training script
        user_train_model(optimizer, prune_retrain_model, train_data)
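user_create_optimizer and user_train_model are likewise user-supplied helpers. A minimal sketch under that assumption (the names, signatures, and checkpoint path are illustrative):

```python
import torch

def user_create_optimizer(model, lr=1e-4):
    """Hypothetical helper: a plain SGD optimizer over the modified model."""
    return torch.optim.SGD(model.parameters(), lr=lr)

def user_train_model(optimizer, model, train_data, epochs=1,
                     ckpt_path='./ckpt/user_model.pth'):
    """Hypothetical training loop over (inputs, labels) batches that
    checkpoints the state_dict so training can be resumed later."""
    loss_fn = torch.nn.CrossEntropyLoss()
    model.train()
    for _ in range(epochs):
        for inputs, labels in train_data:
            optimizer.zero_grad()
            loss = loss_fn(model(inputs), labels)
            loss.backward()
            optimizer.step()
    torch.save(model.state_dict(), ckpt_path)
```

Saving the state_dict after each training run is what makes the interrupted-training recovery flow below possible.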
        
    3. (Optional) Save the filter-level sparse model by calling save_prune_retrain_model. If the model is saved as a .pth file instead, skip this step and see Follow-up Operations.
      prune_retrain_model = amct.save_prune_retrain_model(
           model=prune_retrain_model,
           save_path=save_path,
           input_data=input_data)
      
  4. (Optional) Run inference on the sparse ONNX model in the ONNX Runtime environment based on the test dataset (test_data) to test the accuracy. Update the sample code based on your situation.

    Check the accuracy loss of the fake-sparse model by comparing it with that of the source model (see 2).

    prune_retrain_model = './results/user_model_fake_prune_model.onnx'
    user_do_inference_onnx(prune_retrain_model, test_data, test_iterations)
    

If the training process is interrupted, restore data from the checkpoints to resume the training.

  1. Import the AMCT package and set the log level (see Post-installation Actions for details).
    import amct_pytorch as amct
    
  2. Prepare a source model.
    ori_model = user_create_model()
    
  3. Call AMCT to resume the sparse retraining process.
    1. Modify the model, insert the filter-level sparsity mask operator into the graph structure, and obtain the restored prune_retrain_model.
      model = ori_model
      input_data = ori_model_input_data
      record_file = './tmp/record.txt'
      config_defination = './prune_cfg.cfg'
      save_pth_path = '/your/path/to/save/tmp.pth'
      # state_dict_path points to your trained checkpoint; restore it first.
      model.load_state_dict(torch.load(state_dict_path))
      prune_retrain_model = amct.restore_prune_retrain_model(model=model,
                                                             input_data=input_data,
                                                             record_file=record_file,
                                                             config_defination=config_defination,
                                                             save_pth_path=save_pth_path)
      
    2. Implement gradient descent optimization on the modified model, restore parameters from the checkpoints, and train the model on the training dataset. (Update the sample code based on your situation.)
      1. Restore the model parameters from the checkpoint saved after sparsity.
        quant_pth = './ckpt/user_prune_model'
        prune_retrain_model.load_state_dict(torch.load(quant_pth))
        
      2. Implement gradient descent optimization on the modified graph.
        Perform this step after model parameters are restored.
        optimizer = user_create_optimizer(prune_retrain_model)
        
      3. Train the model on the training dataset.
        user_train_model(optimizer, prune_retrain_model, train_data)
        
    3. (Optional) Save the filter-level sparse model by calling save_prune_retrain_model. If the model is saved as a .pth file instead, skip this step and see Follow-up Operations.
      prune_retrain_model = amct.save_prune_retrain_model(
           model=prune_retrain_model,
           save_path=save_path,
           input_data=input_data)
      
  4. (Optional) Run inference on the sparse ONNX model in the ONNX Runtime environment based on the test dataset (test_data) to test the accuracy. Update the sample code based on your situation.

    Check the accuracy loss of the fake-sparse model by comparing it with that of the source model.

    prune_retrain_model = './results/user_model_fake_prune_model.onnx'
    user_do_inference_onnx(prune_retrain_model, test_data, test_iterations)
    

Follow-up Operations

If the sparsified model is output in .pth format, refer to this section. If save_prune_retrain_model is called, skip this section.

The output .pth model is not directly deployable for inference. Before using the ATC tool to convert the model, you need to convert the .pth model into an ONNX network model yourself, or call save_prune_retrain_model to save the .pth model as the final ONNX fake-sparse model and the deployable model. The following is an example of calling save_prune_retrain_model:

prune_retrain_model = amct.save_prune_retrain_model(
     model=prune_retrain_model,
     save_path=save_path,
     input_data=input_data)