Sparsity Process
This section describes the layers supported by 2:4 structured sparsity, the API call sequence, and an example.
Due to hardware restrictions, the Atlas 200/300/500 Inference Product and Atlas Training Series Product do not support the 2:4 structured sparsity feature; enabling it on these products yields little performance benefit.
AMCT supports retraining-based 2:4 structured sparsity. The layers that support 2:4 structured sparsity and their restrictions are listed below. For a sparsity example, see Sample List.
| Technique | Supported Layer Type | Restrictions |
|---|---|---|
| 2:4 Structured Sparsity | torch.nn.Linear | Layers that share weights do not support sparsity. |
| | torch.nn.Conv2d | |
| | torch.nn.ConvTranspose2d | |
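In a 2:4 structured sparsity pattern, every aligned group of four weights keeps at most two non-zero values. A minimal, framework-independent sketch of that pruning rule (pure Python; in practice AMCT applies it to the weight tensors of the supported layers, and this function is only illustrative):

```python
def prune_2_4(weights):
    """Zero the two smallest-magnitude values in each aligned group of four.

    Illustrative only: shows the 2:4 pattern, not AMCT's actual kernel.
    """
    pruned = []
    for i in range(0, len(weights), 4):
        group = weights[i:i + 4]
        # Indices of the two largest-magnitude entries in this group.
        keep = sorted(range(len(group)),
                      key=lambda j: abs(group[j]), reverse=True)[:2]
        pruned.extend(v if j in keep else 0.0 for j, v in enumerate(group))
    return pruned

print(prune_2_4([0.9, -0.1, 0.4, 0.05, -0.7, 0.2, 0.1, 0.8]))
# → [0.9, 0.0, 0.4, 0.0, -0.7, 0.0, 0.0, 0.8]
```

Each group of four retains its two largest-magnitude weights, which is what allows the hardware to skip the zeroed positions.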
API Call Sequence
Figure 1 shows the API call sequence for 2:4 structured sparsity.
- Build the source PyTorch model and call the create_prune_retrain_model API to modify it, replacing the operators to be sparsified with 2:4 structured sparsity operators.
- Train the modified model until the accuracy meets your requirement. If the training process is interrupted, call the restore_prune_retrain_model API to re-sparsify the source model based on the sparsity record file and resume retraining until the accuracy meets your requirement.
- Call the save_prune_retrain_model API on the retrained 2:4 structured sparsity model to restore the replaced operators and prune the weight data, generating the final sparsity-simulated ONNX model and the deployable model.
Examples
- Take the following steps to get started. Update the sample code based on your situation.
- Adjust the arguments passed to the AMCT APIs as required. Sparsity relies on the user's training result, so ensure that a PyTorch training script that achieves satisfactory accuracy is available.
- (Optional) Run inference on the source model in the PyTorch environment based on the test dataset to validate the inference script and environment setup. (Update the sample code based on your situation.)
This step is recommended because it verifies that the source model runs inference correctly with acceptable accuracy. You can use a subset of the test dataset to improve efficiency.
```python
ori_model.load()
# Test the model.
user_test_model(ori_model, test_data, test_iterations)
```
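Running on a subset is as simple as bounding the number of test iterations. A minimal sketch of what such a helper could look like (the name user_test_model comes from the sample above; the model/data interface here is an illustrative assumption, not the AMCT API):

```python
from itertools import islice

def user_test_model(model, test_data, test_iterations):
    """Illustrative stand-in for the user's test routine.

    Runs at most `test_iterations` samples and returns the accuracy.
    `model` is assumed to be a callable returning a predicted label and
    `test_data` an iterable of (sample, label) pairs.
    """
    correct = total = 0
    for sample, label in islice(test_data, test_iterations):
        correct += int(model(sample) == label)
        total += 1
    return correct / total if total else 0.0

# Trivial stand-in "model" and data, only to show the call shape.
acc = user_test_model(lambda x: x % 2, [(1, 1), (2, 0), (3, 0), (4, 0)], 3)
```

Bounding the iteration count trades a small amount of statistical confidence for a much faster sanity check.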
- Call AMCT to execute the 2:4 structured sparsity process.
- Modify the source model and replace the to-be-sparsified operator with the 2:4 structured sparsity operator.
Before performing this step, restore the already trained parameters, for example, via ori_model.load() in step 1.
```python
simple_cfg = './retrain.cfg'
record_file = './tmp/record.txt'
prune_retrain_model = amct.create_prune_retrain_model(
    model=ori_model,
    input_data=ori_model_input_data,
    config_defination=simple_cfg,
    record_file=record_file)
```
- Implement gradient descent optimization on the modified graph and train the graph on the training dataset. (Update the sample code based on your situation.)
- Call save_prune_retrain_model to save the model, restore the replaced operators, and perform structured sparsification on the weight data, generating the final sparsity-simulated ONNX model and the deployable model.
```python
prune_retrain_model = amct.save_prune_retrain_model(
    model=prune_retrain_model,
    save_path=save_path,
    input_data=input_data)
```
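After saving, the sparsified weights should contain at least two zeros in every aligned group of four. A hedged sanity check of that invariant (pure Python; the weight rows are illustrative, and extracting real weights from the saved ONNX model is left to the user's own tooling):

```python
def is_2_4_sparse(row):
    """Check that every aligned group of 4 values has at least 2 zeros."""
    return all(row[i:i + 4].count(0.0) >= 2
               for i in range(0, len(row), 4))

# Illustrative weight rows as they might look after sparsification.
assert is_2_4_sparse([0.5, 0.0, 0.0, -1.2, 0.0, 0.3, 0.7, 0.0])
assert not is_2_4_sparse([0.5, 0.1, 0.2, -1.2, 0.0, 0.3, 0.7, 0.0])
```

A check like this can catch accidental use of a dense checkpoint before the model is handed to the deployment toolchain.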
- (Optional) Run inference on the sparsity-simulated model in the ONNX Runtime environment based on the test dataset (test_data) to test the accuracy. Update the sample code based on your situation.
Compare the accuracy of the sparsity-simulated model with that of the original model to observe the impact of 2:4 structured sparsity on the accuracy.
```python
prune_retrain_model = './results/user_model_fake_prune_model.onnx'
user_do_inference_onnx(prune_retrain_model, test_data, test_iterations)
```
If the training process is interrupted, restore data from the checkpoints to resume the training.
- Prepare a source model.
```python
ori_model = user_create_model()
```
- Call AMCT to resume the sparsity retraining process.
- Modify the model, replace the to-be-sparsified operator with the 2:4 structured sparsity operator, and save the model as a new prune_model.
```python
model = ori_model
input_data = ori_model_input_data
record_file = './tmp/record.txt'
config_defination = './prune_cfg.cfg'
save_pth_path = '/your/path/to/save/tmp.pth'
model.load_state_dict(torch.load(state_dict_path))
prune_retrain_model = amct.restore_prune_retrain_model(
    model=model,
    input_data=input_data,
    record_file=record_file,
    config_defination=config_defination,
    save_pth_path=save_pth_path,
    pth_type='state_dict')
```
- Implement gradient descent optimization on the modified model, restore parameters from the checkpoints, and train the model on the training dataset. (Update the sample code based on your situation.)
- Restore the model parameters from the checkpoints after sparsity.
```python
quant_pth = './ckpt/user_prune_model'
prune_retrain_model.load_state_dict(torch.load(quant_pth))
```
- Implement gradient descent optimization on the modified graph.
- Restore the model from existing checkpoints and train the model.
Note: Restore model parameters from the trained checkpoints before training.
```python
user_train_model(optimizer, prune_retrain_model, train_data)
```
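user_train_model stands for the user's own training routine. A minimal sketch of what such a gradient-descent loop might look like (pure Python, a single scalar weight, synthetic data; all names and values are illustrative, not part of the AMCT API):

```python
def user_train_model(lr, w, train_data, epochs=50):
    """Fit y = w * x by plain SGD on a squared-error loss.

    Illustrative stand-in for the user's training routine: `lr` plays
    the role of the optimizer and `w` the (sparsified) model parameter.
    """
    for _ in range(epochs):
        for x, y in train_data:
            pred = w * x
            grad = 2 * (pred - y) * x  # d/dw of (w*x - y)**2
            w -= lr * grad
    return w

# Synthetic data generated from y = 3x; training should recover w close to 3.
data = [(x, 3.0 * x) for x in (1.0, 2.0, 0.5)]
w_trained = user_train_model(lr=0.05, w=0.0, train_data=data, epochs=50)
```

In the real flow this loop would iterate over the sparsified PyTorch model's parameters via an optimizer, but the update rule is the same gradient-descent step shown here.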
- Call save_prune_retrain_model to save the model, restore the replaced operators, and perform structured sparsification on the weight data, generating the final sparsity-simulated ONNX model and the deployable model.
```python
prune_retrain_model = amct.save_prune_retrain_model(
    model=prune_retrain_model,
    save_path=save_path,
    input_data=input_data)
```
- (Optional) Run inference on the sparsity-simulated model in the ONNX Runtime environment based on the test dataset (test_data) to test the accuracy. Update the sample code based on your situation.
Compare the accuracy of the sparsity-simulated model with that of the original model to observe the impact of 2:4 structured sparsity on the accuracy.
```python
prune_retrain_model = './results/user_model_fake_prune_model.onnx'
user_do_inference_onnx(prune_retrain_model, test_data, test_iterations)
```
