create_compressed_retrain_model

Function Usage

Applies to static compression combination. Compresses the input model according to the specified static compression combination configuration file. The function prunes the input model using filter-level sparsity or 2:4 structured sparsity (if sparsity is configured), inserts quantization operators into the model (if quantization is configured), namely QAT layers for activations and weights and a searchN layer, generates the sparsity record file record_file (if sparsity is configured), and returns the modified torch.nn.Module model.
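To make the two pruning modes concrete, the sketch below shows the generic ideas in plain Python on flat lists of weights. It is an illustration under stated assumptions, not AMCT's implementation: the L1-norm ranking for filter-level sparsity is one common criterion that AMCT may or may not use, and the 2:4 rule (keep the 2 largest-magnitude values in every group of 4 consecutive weights) is the standard definition of 2:4 structured sparsity. The names filters_to_keep, two_four_mask, and apply_mask are hypothetical.

```python
from typing import List


def filters_to_keep(filters: List[List[float]], keep_ratio: float) -> List[int]:
    """Filter-level sparsity sketch: keep the filters with the largest L1 norms.

    Assumption: filters are ranked by L1 norm; AMCT's actual criterion may differ.
    Returns the indices of the kept filters in ascending order.
    """
    n_keep = max(1, int(len(filters) * keep_ratio))
    norms = [sum(abs(x) for x in f) for f in filters]
    ranked = sorted(range(len(filters)), key=lambda i: norms[i], reverse=True)
    return sorted(ranked[:n_keep])


def two_four_mask(weights: List[float]) -> List[int]:
    """2:4 structured sparsity sketch: in each group of 4 weights, keep 2."""
    if len(weights) % 4 != 0:
        raise ValueError("number of weights must be a multiple of 4")
    mask = [0] * len(weights)
    for start in range(0, len(weights), 4):
        group = range(start, start + 4)
        # Indices of the two largest-magnitude entries in this group.
        keep = sorted(group, key=lambda i: abs(weights[i]), reverse=True)[:2]
        for i in keep:
            mask[i] = 1
    return mask


def apply_mask(weights: List[float], mask: List[int]) -> List[float]:
    """Zero out the pruned weights."""
    return [w * m for w, m in zip(weights, mask)]


if __name__ == "__main__":
    w = [0.1, -0.9, 0.4, 0.05, 2.0, -0.3, 0.0, -1.5]
    print(two_four_mask(w))  # [0, 1, 1, 0, 1, 0, 0, 1]
    print(filters_to_keep([[1, -1], [0.1, 0.1], [3, 0], [0, 0.5]], 0.5))  # [0, 2]
```

Filter-level sparsity removes whole filters (and therefore output channels), while 2:4 sparsity keeps the tensor shape and enforces a fixed 50% pattern inside every group of four weights.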

Constraints

The compression combination configuration file must contain at least one of the following: a sparsity configuration or a quantization configuration.
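The constraint above amounts to a simple "at least one of" check. The sketch below is purely illustrative: it models the parsed configuration as a Python dict, and the keys "sparsity" and "quantization" are hypothetical stand-ins, not field names from retrain_config_pytorch.proto.

```python
def has_required_section(config: dict) -> bool:
    """Return True if at least one compression configuration is present.

    Hypothetical model: the 'sparsity' and 'quantization' keys stand in for the
    two configuration kinds; they are not actual AMCT configuration fields.
    """
    return bool(config.get("sparsity")) or bool(config.get("quantization"))
```

For example, a configuration with only a quantization section passes the check, while an empty configuration does not.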

Prototype

compressed_retrain_model = create_compressed_retrain_model(model, input_data, config_defination, record_file)

Parameters

Parameter

Input/Return

Description

Restriction

model

Input

PyTorch model.

A torch.nn.Module.

input_data

Input

Input data of the model.

A tuple.

config_defination

Input

Simplified configuration file for static compression combination.

That is, the simplified configuration file compressed.cfg generated from the retrain_config_pytorch.proto file.

The retrain_config_pytorch.proto template is stored at /amct_pytorch/proto/retrain_config_pytorch.proto under the AMCT installation path.

For details about the parameters in the retrain_config_pytorch.proto file and the generated simplified configuration file compressed.cfg, see Simplified QAT Configuration File.

A string.

record_file

Input

Path and name of the file in which the sparsity and quantization factors are recorded.

A string.

compressed_retrain_model

Return

A torch.nn.Module model that has been sparsified according to the configuration file (if sparsity is configured) and into which the quantization-related layers have been inserted (if quantization is configured).

A torch.nn.Module.
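The QAT layers inserted into the returned model simulate low-bit arithmetic during retraining. The sketch below shows the generic quantize-dequantize ("fake quantization") step in plain Python, assuming symmetric per-tensor quantization with a max-abs scale. This is a textbook QAT building block used for illustration only; the actual quantization algorithm of AMCT's QAT layers is not described here and may differ. The function name fake_quantize is hypothetical.

```python
from typing import List, Tuple


def fake_quantize(values: List[float], num_bits: int = 8) -> Tuple[List[float], float]:
    """Symmetric per-tensor quantize-dequantize, as used in generic QAT.

    Assumption: scale = max|x| / (2**(num_bits - 1) - 1); AMCT may compute
    its quantization factors differently.
    """
    qmax = 2 ** (num_bits - 1) - 1  # 127 for 8-bit quantization
    max_abs = max(abs(v) for v in values)
    if max_abs == 0:
        # All-zero tensor: nothing to quantize; report a unit scale.
        return list(values), 1.0
    scale = max_abs / qmax
    out = []
    for v in values:
        q = round(v / scale)            # map to the integer grid
        q = max(-qmax, min(qmax, q))    # clamp to the symmetric integer range
        out.append(q * scale)           # dequantize back to floating point
    return out, scale
```

During retraining, the forward pass uses the dequantized values, so the network learns weights that tolerate the rounding error, which for in-range values is at most half a quantization step (scale / 2).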

Return Value

The model prepared for static compression combination retraining.

Outputs

The static compression combination record file record_file. If the simplified configuration file contains a sparsity configuration, record_file contains the sparsity records after the function is executed.

Examples

import os

import torch
import amct_pytorch as amct

# Set up a network for static compression combination.
# build_model(), input_shape, and TMP are user-defined placeholders.
model = build_model()
input_data = tuple([torch.randn(input_shape)])

# Call the static compression combination API.
record_file = os.path.join(TMP, 'compressed_record.txt')
config_defination = './compressed_cfg.cfg'

compressed_retrain_model = amct.create_compressed_retrain_model(
    model,
    input_data,
    config_defination,
    record_file)