restore_compressed_retrain_model

Function Usage

Statically compresses the input model according to the specified combination compression configuration file and record file (sparsification before quantization), and loads the saved weights. Specifically, the function sparsifies the input model based on the sparsity records in the given record_file, inserts the quantization-related operators (quantization-aware-training layers for data and weights, and the SearchN layer) into the model, and loads the checkpoint weight parameters saved during training. It returns the modified torch.nn.Module model.

Constraints

The compression combination configuration file must contain at least one of the following configurations: sparsity configuration or quantization configuration.

Prototype

compressed_retrain_model = restore_compressed_retrain_model(model, input_data, config_defination, record_file, pth_file, state_dict_name=None)

Parameters

| Parameter | Input/Return | Description | Restriction |
| --- | --- | --- | --- |
| model | Input | Torch model. | A torch.nn.Module. |
| input_data | Input | Input data of the model: a torch.Tensor or a tuple of torch.Tensor. | A tuple. |
| config_defination | Input | Simplified configuration file for the static compression combination, that is, prune.cfg generated from the retrain_config_pytorch.proto file. The retrain_config_pytorch.proto template is stored in /amct_pytorch/proto/retrain_config_pytorch.proto in the AMCT installation path. For details about the parameters in the retrain_config_pytorch.proto file and the generated simplified quantization configuration file compressed.cfg, see Simplified QAT Configuration File. | A string. |
| record_file | Input | File where the sparsity and quantization factors are recorded. | A string. |
| pth_file | Input | Weight file saved during training. | A string. |
| state_dict_name | Input | Key corresponding to the weights in the weight file. Default: None. | A string. |
| compressed_retrain_model | Return | torch.nn.Module model that has been sparsified based on the sparsity records in record_file, has had the quantization-related layers inserted, and has loaded the weight file; it is ready for static compression combination training. | A torch.nn.Module. |

Outputs

None

Example

import os
import torch
import amct_pytorch as amct
# Build a graph of the network for compression combination.
model = build_model()
input_data = tuple(torch.randn(input_shape))
save_pth_path = '/your/path/to/save/tmp.pth'

TMP = './tmp'  # directory holding the compression record file
record_file = os.path.join(TMP, 'compressed_record.txt')
config_defination = './compressed_cfg.cfg'

# Save the weights under the 'state_dict' key so they can be restored below.
torch.save({'state_dict': model.state_dict()}, save_pth_path)
compressed_retrain_model = amct.restore_compressed_retrain_model(
                                model,
                                input_data,
                                config_defination,
                                record_file,
                                save_pth_path,
                                'state_dict')