restore_prune_retrain_model
Function Usage
Filter-level sparsity or 2:4 structured sparsity API. Only one of the two sparsity features can be enabled at a time.
This API sparsifies the input graph based on the given sparsity record file record_file and returns a torch.nn.Module model that can be used for retraining.
Prototype
prune_retrain_model = restore_prune_retrain_model(model, input_data, record_file, config_defination, pth_file, state_dict_name=None)
Parameters
| Parameter | Input/Return | Description | Restriction |
|---|---|---|---|
| model | Input | Model to prune, with weights loaded. | A torch.nn.Module |
| input_data | Input | Input data of the model. A single torch.tensor is equivalent to a tuple containing one torch.tensor. | A tuple |
| record_file | Input | Path (including the file name) of the sparsity record file generated by the create_prune_retrain_model API, which ensures that the models generated by the two APIs are consistent. | A string |
| config_defination | Input | Simplified configuration file, that is, prune.cfg generated from the retrain_config_pytorch.proto file. The retrain_config_pytorch.proto template is stored in /amct_pytorch/proto/retrain_config_pytorch.proto in the AMCT installation path. For details about the parameters in the retrain_config_pytorch.proto file and the generated simplified quantization configuration file prune.cfg, see Simplified QAT Configuration File. | A string |
| pth_file | Input | Weight file saved during training. | A string |
| state_dict_name | Input | Key corresponding to the weights in the weight file. | A string. Default: None |
| prune_retrain_model | Return | Result torch.nn.Module model that can be used for retraining. | A torch.nn.Module |
Returns
Result torch.nn.Module model (prune_retrain_model) that can be used for retraining.
Outputs
Pruned model.
Examples
```python
import os

import torch
import amct_pytorch as amct

# Create a graph of the network to prune.
config_defination = './prune_cfg.cfg'
model = build_model()
input_data = tuple([torch.randn(input_shape)])
save_pth_path = '/your/path/to/save/tmp.pth'
model.load_state_dict(torch.load(state_dict_path))

# Call the API for sparsifying the model.
record_file = os.path.join(TMP, 'scale_offset_record.txt')
prune_retrain_model = amct.restore_prune_retrain_model(
    model,
    input_data,
    record_file,
    config_defination,
    save_pth_path,
    'state_dict')
```
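Because the returned prune_retrain_model is an ordinary torch.nn.Module, it can be retrained with a standard PyTorch training loop. The sketch below is a minimal, hedged illustration: it substitutes a small Sequential network for the pruned model (the real one comes from restore_prune_retrain_model), and the optimizer, loss, and dummy data are all assumptions, not part of the AMCT API.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the model returned by
# amct.restore_prune_retrain_model(); any torch.nn.Module retrains the same way.
prune_retrain_model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),   # (N, 3, 32, 32) -> (N, 8, 30, 30)
    nn.ReLU(),
    nn.Flatten(),                     # -> (N, 8 * 30 * 30)
    nn.Linear(8 * 30 * 30, 10))       # -> (N, 10) class scores

optimizer = torch.optim.SGD(prune_retrain_model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# One retraining step on dummy data (replace with your real dataloader).
inputs = torch.randn(4, 3, 32, 32)
labels = torch.randint(0, 10, (4,))

optimizer.zero_grad()
loss = criterion(prune_retrain_model(inputs), labels)
loss.backward()
optimizer.step()
```

After retraining converges, the model would typically be passed to the corresponding save/export API to produce the final sparsified deployment model.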