create_prune_retrain_model
Function Usage
API for filter-level sparsity or 2:4 structured sparsity. Only one of the two sparsity features can be enabled at a time.
This API sparsifies the input graph according to the given sparsity configuration file, inserts or replaces the related operators in the input graph, generates the sparsity record file record_file, and returns a torch.nn.Module model that can be used for retraining.
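To make the two sparsity modes concrete, the following is a minimal pure-Python sketch of what each pattern looks like on plain lists of weights. It is not AMCT code: the helper names `mask_2_of_4` and `prune_filters` are hypothetical, and the selection criteria (largest magnitude within each group of four, smallest L1 norm per filter) are assumptions chosen for illustration. The criteria AMCT actually applies are determined by the configuration file.

```python
from typing import List


def mask_2_of_4(weights: List[float]) -> List[float]:
    """2:4 structured sparsity: keep 2 values in every group of 4, zero the rest.

    Assumption: the two largest-magnitude values in each group are kept.
    """
    if len(weights) % 4 != 0:
        raise ValueError("weight count must be a multiple of 4")
    pruned: List[float] = []
    for start in range(0, len(weights), 4):
        group = weights[start:start + 4]
        # Indices of the two largest magnitudes in this group.
        keep = sorted(range(4), key=lambda i: abs(group[i]), reverse=True)[:2]
        pruned.extend(v if i in keep else 0.0 for i, v in enumerate(group))
    return pruned


def prune_filters(filters: List[List[float]], ratio: float) -> List[List[float]]:
    """Filter-level sparsity: remove whole filters, which shrinks the layer.

    Assumption: the filters with the smallest L1 norm are removed.
    """
    n_drop = int(len(filters) * ratio)
    order = sorted(range(len(filters)),
                   key=lambda i: sum(abs(v) for v in filters[i]))
    dropped = set(order[:n_drop])
    return [f for i, f in enumerate(filters) if i not in dropped]


row = [0.1, -0.9, 0.5, 0.05, 1.0, 2.0, -3.0, 0.0]
print(mask_2_of_4(row))
print(prune_filters([[1.0, 1.0], [0.1, 0.1], [2.0, -2.0], [0.5, 0.0]], 0.5))
```

The 2:4 pattern keeps the tensor shape and zeroes individual weights, whereas filter-level sparsity removes entire filters and therefore changes the shapes of the layers that consume them. That shape dependency is presumably why the record file tracks the cascade correlations between nodes.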
Prototype
prune_retrain_model = create_prune_retrain_model(model, input_data, config_defination, record_file)
Parameters
| Parameter | Input/Return | Meaning | Restriction |
|---|---|---|---|
| model | Input | Model to prune, with weights loaded. | A torch.nn.Module |
| input_data | Input | Input data of the model. A single torch.Tensor is treated as equivalent to tuple(torch.Tensor). | A tuple |
| config_defination | Input | Simplified configuration file, that is, the prune.cfg file generated from retrain_config_pytorch.proto. The retrain_config_pytorch.proto template is stored in /amct_pytorch/proto/retrain_config_pytorch.proto under the AMCT installation path. For details about the parameters in retrain_config_pytorch.proto and the generated simplified configuration file prune.cfg, see Simplified QAT Configuration File. | A string |
| record_file | Input | Path of the sparsity record file, including the file name. This file records the cascade correlations between filter-level sparsity nodes or 2:4 structured sparsity nodes. | A string |
| prune_retrain_model | Returns | Resulting torch.nn.Module model that can be used for retraining. | A torch.nn.Module. Default: None |
Returns
Pruned model.
Outputs
None
Example
```python
import os

import torch
import amct_pytorch as amct

# build_model, state_dict_path, input_shape, and TMP are placeholders
# that the user must define for their own network and environment.

# Create a graph of the network to prune.
model = build_model()
model.load_state_dict(torch.load(state_dict_path))
input_data = tuple([torch.randn(input_shape)])

# Call the API for sparsifying models.
record_file = os.path.join(TMP, 'scale_offset_record.txt')
cfg_file = './prune_config.cfg'
prune_retrain_model = amct.create_prune_retrain_model(
    model,
    input_data,
    cfg_file,
    record_file)
```
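Because the returned object is a torch.nn.Module, retraining can presumably follow an ordinary PyTorch training loop. The sketch below illustrates such a loop under stated assumptions: a small `nn.Linear` stands in for `prune_retrain_model` so that the snippet runs without AMCT installed, and the `retrain` helper, the synthetic data, the optimizer, the learning rate, and the epoch count are all illustrative choices rather than AMCT requirements. Any AMCT-specific steps that follow retraining are not shown.

```python
import torch
from torch import nn


def retrain(model: nn.Module, inputs: torch.Tensor, targets: torch.Tensor,
            epochs: int = 20, lr: float = 0.05) -> list:
    """Run a plain full-batch training loop and return the loss per epoch."""
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()
    losses = []
    model.train()
    for _ in range(epochs):
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses


torch.manual_seed(0)
# Assumption: stand-in for the module returned by create_prune_retrain_model.
stand_in_model = nn.Linear(8, 1)
inputs = torch.randn(64, 8)
targets = inputs.sum(dim=1, keepdim=True)

losses = retrain(stand_in_model, inputs, targets)
print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.4f}")
```

In a real workflow, `stand_in_model` would be replaced by the model returned from `amct.create_prune_retrain_model`, and the synthetic tensors by the original training dataset and loss function.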