create_compressed_retrain_model
Function Usage
Applies to static compression combination. Compresses the input model according to the specified static compression combination configuration file. The function prunes the input model using either filter-level sparsity or 2:4 structured sparsity (if a sparsity configuration exists). It inserts quantization operators into the model (if a quantization configuration exists): QAT layers for activations and weights, and searchN layers. It generates the sparsity record file record_file (if a sparsity configuration exists). Finally, it returns the modified torch.nn.Module model.
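The two pruning modes mentioned above can be illustrated on plain Python lists. This is a minimal sketch, not AMCT's implementation. It assumes magnitude-based selection: 2:4 sparsity keeps the 2 largest-magnitude weights in every group of 4 consecutive weights, and filter-level sparsity zeroes out the filters with the smallest L1 norm. The helper names mask_2_4 and prune_filters are hypothetical.

```python
# Illustrative sketch only: magnitude-based selection is an assumption,
# not necessarily the criterion AMCT uses.

def mask_2_4(weights):
    """Return a 0/1 mask keeping the 2 largest-magnitude values in every group of 4."""
    if len(weights) % 4 != 0:
        raise ValueError("length must be a multiple of 4")
    mask = [0] * len(weights)
    for start in range(0, len(weights), 4):
        group = range(start, start + 4)
        # Indices of the two entries with the largest absolute value in this group.
        keep = sorted(group, key=lambda i: abs(weights[i]), reverse=True)[:2]
        for i in keep:
            mask[i] = 1
    return mask


def prune_filters(filters, num_to_prune):
    """Zero out the num_to_prune filters with the smallest L1 norm (filter-level sparsity)."""
    norms = [sum(abs(w) for w in f) for f in filters]
    order = sorted(range(len(filters)), key=lambda i: norms[i])
    pruned = set(order[:num_to_prune])
    return [[0.0] * len(f) if i in pruned else list(f) for i, f in enumerate(filters)]


if __name__ == "__main__":
    w = [0.1, -0.9, 0.4, 0.05, 2.0, -0.3, 0.0, 1.5]
    m = mask_2_4(w)
    print(m)  # [0, 1, 1, 0, 1, 0, 0, 1]
    filters = [[1.0, -1.0], [0.1, 0.1], [3.0, 0.0]]
    print(prune_filters(filters, 1))  # [[1.0, -1.0], [0.0, 0.0], [3.0, 0.0]]
```

2:4 sparsity fixes the pattern inside every group of four, so exactly half of each group is zero. Filter-level sparsity instead removes whole output filters.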
Constraints
The compression combination configuration file must contain at least one of the following: a sparsity configuration or a quantization configuration.
Prototype
compressed_retrain_model = create_compressed_retrain_model(model, input_data, config_defination, record_file)
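The restrictions in the parameter table require input_data to be a tuple and the two file arguments to be strings. A quick pre-call check can therefore catch common mistakes, such as passing a bare tensor instead of a tuple. The helper below, find_arg_problems, is hypothetical and not part of AMCT. It checks only those restrictions, plus whether the configuration file exists (an extra assumption). It deliberately leaves model unchecked so that it runs without PyTorch.

```python
import os


def find_arg_problems(input_data, config_defination, record_file):
    """Hypothetical pre-check (not an AMCT API): return a list of problems found
    in the arguments of create_compressed_retrain_model; an empty list means OK.
    The model argument is not checked, so this runs without torch."""
    problems = []
    # Restriction: input_data is a tuple, e.g. (tensor,) rather than a bare tensor.
    if not isinstance(input_data, tuple):
        problems.append("input_data must be a tuple, got %s" % type(input_data).__name__)
    # Restriction: config_defination and record_file are strings.
    for name, value in (("config_defination", config_defination),
                        ("record_file", record_file)):
        if not isinstance(value, str) or not value:
            problems.append("%s must be a non-empty string" % name)
    # Extra assumption: the configuration file should already exist on disk.
    if (isinstance(config_defination, str) and config_defination
            and not os.path.isfile(config_defination)):
        problems.append("configuration file not found: %s" % config_defination)
    return problems


if __name__ == "__main__":
    # A list instead of a tuple and a missing config file are both reported.
    for msg in find_arg_problems([0.0], "./missing_cfg.cfg", "record.txt"):
        print(msg)
```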
Parameters
| Parameter | Input/Return | Description | Restriction |
|---|---|---|---|
| model | Input | PyTorch model. | A torch.nn.Module. |
| input_data | Input | Input data of the model. | A tuple. |
| config_defination | Input | Simplified configuration file for static compression combination (for example, compressed.cfg), generated from the retrain_config_pytorch.proto template. The template is stored in /amct_pytorch/proto/retrain_config_pytorch.proto in the AMCT installation path. For details about the parameters in retrain_config_pytorch.proto and in the generated simplified configuration file, see Simplified QAT Configuration File. | A string. |
| record_file | Input | Path and file name of the record file that stores the sparsity and quantization factors. | A string. |
| compressed_retrain_model | Return | Model that has been sparsified according to the configuration file (if a sparsity configuration exists) and into which quantization-related layers have been inserted (if a quantization configuration exists). | A torch.nn.Module. |
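The quantization-related layers inserted into the returned model simulate low-bit arithmetic during retraining. As a rough illustration only, the sketch below performs symmetric per-tensor INT8 fake quantization (quantize, clamp, dequantize) on a list of floats. The function name fake_quantize, the symmetric scheme, and the max-|x| scale are assumptions. AMCT's QAT and searchN layers may use different algorithms.

```python
def fake_quantize(values, num_bits=8):
    """Quantize-dequantize a list of floats with a symmetric per-tensor scale.

    Returns (dequantized_values, scale). Assumed scheme, not AMCT's."""
    qmax = 2 ** (num_bits - 1) - 1   # 127 for 8 bits
    qmin = -qmax - 1                 # -128 for 8 bits
    max_abs = max(abs(v) for v in values)
    if max_abs == 0:
        # All-zero input: nothing to scale; return it unchanged.
        return [float(v) for v in values], 1.0
    scale = max_abs / qmax
    out = []
    for v in values:
        q = round(v / scale)          # nearest integer level (ties go to even)
        q = max(qmin, min(qmax, q))   # clamp to the signed num_bits range
        out.append(q * scale)         # dequantize back to float
    return out, scale


if __name__ == "__main__":
    weights = [0.0, 0.5, -1.0, 1.0, 0.3]
    deq, scale = fake_quantize(weights)
    print("scale:", scale)
    for w, d in zip(weights, deq):
        print(f"{w:+.4f} -> {d:+.6f}")
```

Each output stays within half a quantization step (scale / 2) of its input. This rounding error is what retraining with the inserted layers teaches the model to tolerate.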
Return Value
The model with the static compression combination applied, ready for retraining.
Outputs
Static compression combination record file record_file. If the simplified configuration file contains a sparsity configuration, record_file contains the sparsity records after the function is executed.
Examples
```python
import os

import torch
import amct_pytorch as amct

# Set up a network for static compression combination.
# build_model(), input_shape and TMP are placeholders for your own
# model constructor, input shape and temporary directory.
model = build_model()
input_data = tuple([torch.randn(input_shape)])

# Call the static compression combination API.
record_file = os.path.join(TMP, 'compressed_record.txt')
config_defination = './compressed_cfg.cfg'
compressed_retrain_model = amct.create_compressed_retrain_model(
    model,
    input_data,
    config_defination,
    record_file)
```