restore_compressed_retrain_model
Function Usage
Restores a model for static combined compression (sparsification followed by quantization) retraining. The function sparsifies the input model based on the sparsity records in the given record_file, inserts the quantization-related operators (quantization-aware training layers for activations and weights, and the SearchN layer) into the model, loads the checkpoint weight parameters saved during training, and returns the modified torch.nn.Module model.
Constraints
The combined compression configuration file must contain at least one of the following configurations: a sparsity configuration or a quantization configuration.
Prototype
compressed_retrain_model = restore_compressed_retrain_model(model, input_data, config_defination, record_file, pth_file, state_dict_name=None)
Parameters
| Parameter | Input/Return | Description | Restrictions |
|---|---|---|---|
| model | Input | Torch model to be restored. | A torch.nn.Module. |
| input_data | Input | Input data of the model: a torch.tensor, or a tuple of torch.tensors. | A tuple. |
| config_defination | Input | Simplified configuration file for static combined compression, that is, prune.cfg, which is generated from the retrain_config_pytorch.proto file. The retrain_config_pytorch.proto template is stored in /amct_pytorch/proto/retrain_config_pytorch.proto in the AMCT installation path. For details about the parameters in the retrain_config_pytorch.proto file and the generated simplified quantization configuration file compressed.cfg, see Simplified QAT Configuration File. | A string. |
| record_file | Input | File in which the sparsity and quantization factors have been recorded. | A string. |
| pth_file | Input | Weight file saved during training. | A string. |
| state_dict_name | Input | Key of the weights in the weight file. | A string. Default: None. |
| compressed_retrain_model | Return | torch.nn.Module model that has been sparsified based on the sparsity records in record_file, has had the quantization-related layers inserted, and has been loaded with the weight file. | A torch.nn.Module. |
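The pth_file and state_dict_name parameters work together: when the checkpoint nests the weights under a key (for example, a checkpoint saved as {'state_dict': model.state_dict()}), state_dict_name names that key so the weights can be unwrapped. The sketch below illustrates this lookup with a plain dict; it is a hypothetical stand-in for how such a checkpoint is typically consumed, not AMCT's actual implementation.

```python
# Illustrative checkpoint: weights nested under the 'state_dict' key,
# alongside other training metadata.
checkpoint = {'state_dict': {'conv1.weight': [0.1, 0.2]}, 'epoch': 10}

def load_weights(ckpt, state_dict_name=None):
    """Return the weight dict from a checkpoint.

    If state_dict_name is None, the checkpoint itself is treated as the
    weight dict; otherwise the weights are unwrapped from that key.
    (Hypothetical helper for illustration only.)
    """
    return ckpt if state_dict_name is None else ckpt[state_dict_name]

weights = load_weights(checkpoint, 'state_dict')
```

With state_dict_name=None, a checkpoint that is itself a bare state_dict is used directly.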
Returns
A model that has been loaded with the weight file, ready for static combined compression retraining.
Outputs
None
Example
import os
import torch
import amct_pytorch as amct

# Build the network for combined compression retraining.
model = build_model()
input_data = tuple(torch.randn(input_shape))
save_pth_path = '/your/path/to/save/tmp.pth'
record_file = os.path.join(TMP, 'compressed_record.txt')
config_defination = './compressed_cfg.cfg'
torch.save({'state_dict': model.state_dict()}, save_pth_path)
compressed_retrain_model = amct.restore_compressed_retrain_model(
    model,
    input_data,
    config_defination,
    record_file,
    save_pth_path,
    'state_dict')