create_prune_retrain_model

Function Usage

API for filter-level sparsity or 2:4 structured sparsity. Only one of the two sparsity features can be enabled at a time.
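The two sparsity features produce different weight patterns. The minimal sketch below shows the masks each one implies. It assumes magnitude-based selection: the 2 largest |w| in every group of 4 consecutive weights for 2:4 sparsity, and the output filters with the largest L1 norm for filter-level sparsity. It also assumes 2:4 groups run along the last weight dimension. These criteria are illustrative assumptions, not AMCT's documented algorithm, and two_four_mask and filter_mask are hypothetical helper names.

```python
import torch


def two_four_mask(weight: torch.Tensor) -> torch.Tensor:
    """Boolean mask keeping the 2 largest-magnitude values in every group of 4.

    Groups are consecutive elements along the last dimension, which must
    therefore be a multiple of 4 (assumption for this sketch).
    """
    if weight.shape[-1] % 4 != 0:
        raise ValueError("last dimension must be a multiple of 4")
    groups = weight.abs().reshape(-1, 4)
    keep = groups.topk(2, dim=1).indices        # two largest |w| per group
    mask = torch.zeros_like(groups)
    mask.scatter_(1, keep, 1.0)                 # mark the kept positions
    return mask.reshape(weight.shape).bool()


def filter_mask(weight: torch.Tensor, prune_ratio: float) -> torch.Tensor:
    """Boolean mask that zeroes whole output filters (dim 0) with the smallest L1 norm."""
    n_prune = int(weight.shape[0] * prune_ratio)
    norms = weight.abs().reshape(weight.shape[0], -1).sum(dim=1)
    pruned = norms.argsort()[:n_prune]          # indices of the weakest filters
    mask = torch.ones_like(weight, dtype=torch.bool)
    mask[pruned] = False
    return mask


w_linear = torch.randn(8, 16)                   # e.g. a Linear weight (out, in)
sparse_24 = w_linear * two_four_mask(w_linear)  # 2 of every 4 weights kept

w_conv = torch.randn(8, 3, 3, 3)                # e.g. a Conv2d weight (out, in, kh, kw)
sparse_filters = w_conv * filter_mask(w_conv, 0.25)  # 2 of 8 filters zeroed
```

The 2:4 pattern keeps the layer shape and zeroes half of the weights in a fixed fine-grained pattern, while filter-level sparsity removes entire output channels. This difference is one reason the two are treated as separate modes.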

This API sparsifies the input graph based on the given sparsity configuration file, inserts or replaces the related operators in the graph, generates the sparsity record file record_file, and returns a torch.nn.Module model that can be used for retraining.
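The text above does not say which operators AMCT inserts or replaces. As a hypothetical illustration of the general idea, the sketch below swaps every nn.Linear in a model for a wrapper that multiplies its weight by a fixed 0/1 mask. Masked weights therefore contribute nothing in the forward pass and receive zero gradient. MaskedLinear and replace_linears are illustrative names, not AMCT APIs, and the random mask is for demonstration only.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MaskedLinear(nn.Module):
    """Wraps an nn.Linear and applies a fixed 0/1 mask to its weight in forward()."""

    def __init__(self, linear: nn.Linear, mask: torch.Tensor):
        super().__init__()
        self.weight = linear.weight              # reuse the trained parameters
        self.bias = linear.bias                  # may be None
        self.register_buffer("mask", mask.to(linear.weight.dtype))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.linear(x, self.weight * self.mask, self.bias)


def replace_linears(module: nn.Module, make_mask) -> nn.Module:
    """Recursively replace every nn.Linear child with a MaskedLinear, in place."""
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(module, name, MaskedLinear(child, make_mask(child.weight)))
        else:
            replace_linears(child, make_mask)
    return module


torch.manual_seed(0)
net = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
replace_linears(net, lambda w: torch.rand_like(w) > 0.5)  # random mask, demo only
net(torch.randn(3, 8)).sum().backward()
```

The returned model is still a plain torch.nn.Module, so it can be trained with the usual PyTorch tools.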

Prototype

prune_retrain_model = create_prune_retrain_model(model, input_data, config_defination, record_file)

Parameters

Each parameter below is listed with its direction (input or return value), its meaning, and its restriction.

model (input): Model to prune, with its weights loaded. Restriction: a torch.nn.Module.

input_data (input): Input data of the model. A single torch.Tensor is treated as equivalent to a one-element tuple(torch.Tensor). Restriction: a tuple.

config_defination (input): Simplified configuration file, that is, prune.cfg, which is generated from the retrain_config_pytorch.proto file. The retrain_config_pytorch.proto template is stored at /amct_pytorch/proto/retrain_config_pytorch.proto under the AMCT installation path. For details about the parameters in retrain_config_pytorch.proto and in the generated simplified configuration file prune.cfg, see Simplified QAT Configuration File. Restriction: a string.

record_file (input): Path and file name of the file that records sparsity information. This file records the cascade relationships between filter-level sparsity nodes or between 2:4 structured sparsity nodes. Restriction: a string.

prune_retrain_model (return value): Resulting torch.nn.Module model that can be used for retraining. Default: None. Restriction: a torch.nn.Module.
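The record_file parameter refers to cascade relationships between sparsity nodes. The record format is not described here, so the following is an assumption about why such relationships matter. When filter-level sparsity removes output filters from one convolution, the convolution it feeds must drop the matching input channels. The sketch below shows that dependency for two directly connected nn.Conv2d layers. prune_conv_pair is a hypothetical helper, the L1-norm selection is an illustrative criterion, and nothing here reads or writes AMCT's record file.

```python
import torch
import torch.nn as nn


def prune_conv_pair(conv1: nn.Conv2d, conv2: nn.Conv2d, keep: torch.Tensor):
    """Keep output filters `keep` of conv1 and the matching input channels of conv2.

    Assumes conv1 feeds conv2 directly, groups == 1 and dilation == 1.
    """
    new1 = nn.Conv2d(conv1.in_channels, len(keep), conv1.kernel_size,
                     stride=conv1.stride, padding=conv1.padding,
                     bias=conv1.bias is not None)
    new2 = nn.Conv2d(len(keep), conv2.out_channels, conv2.kernel_size,
                     stride=conv2.stride, padding=conv2.padding,
                     bias=conv2.bias is not None)
    with torch.no_grad():
        new1.weight.copy_(conv1.weight[keep])        # drop conv1 output filters
        if conv1.bias is not None:
            new1.bias.copy_(conv1.bias[keep])
        new2.weight.copy_(conv2.weight[:, keep])     # cascade: drop matching conv2 inputs
        if conv2.bias is not None:
            new2.bias.copy_(conv2.bias)
    return new1, new2


torch.manual_seed(0)
conv1 = nn.Conv2d(3, 8, 3, padding=1)
conv2 = nn.Conv2d(8, 4, 3, padding=1)
# Keep the 6 conv1 filters with the largest L1 norm (illustrative criterion).
l1 = conv1.weight.abs().sum(dim=(1, 2, 3))
keep = l1.argsort(descending=True)[:6].sort().values
small1, small2 = prune_conv_pair(conv1, conv2, keep)
```

conv2 only ever sees the kept channels. As a result, the pruned pair computes the same output as the original pair with the dropped conv1 channels set to zero. Pruning conv1 alone, without adjusting conv2, would leave the two layers with mismatched channel counts.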

Returns

Pruned model.

Outputs

None

Example

import os
import torch
import amct_pytorch as amct

# Build the network to prune and load its trained weights.
# build_model, state_dict_path, input_shape and TMP are placeholders for your own code.
model = build_model()
model.load_state_dict(torch.load(state_dict_path))
input_data = tuple([torch.randn(input_shape)])

# Call the API to sparsify the model and obtain a model for retraining.
record_file = os.path.join(TMP, 'scale_offset_record.txt')
cfg_file = './prune_config.cfg'
prune_retrain_model = amct.create_prune_retrain_model(
    model,
    input_data,
    cfg_file,
    record_file)
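After the call returns, prune_retrain_model is a torch.nn.Module that can be used for retraining. The minimal sketch below shows that retraining step with a standard PyTorch loop. Several details are assumptions: the SGD optimizer, the MSE loss, the learning rate, and the small nn.Linear stand-in, which lets the snippet run without AMCT. retrain is a hypothetical helper name. The text above does not say how AMCT keeps the sparsity pattern fixed during training.

```python
import torch
import torch.nn as nn


def retrain(model: nn.Module, inputs, targets, epochs: int = 5, lr: float = 1e-2) -> list:
    """Fine-tune `model` with plain SGD and MSE loss; return the per-epoch losses."""
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()
    model.train()
    losses = []
    for _ in range(epochs):
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())   # loss measured before this epoch's update
    return losses


# Stand-in for prune_retrain_model; in practice, pass the model returned by the API.
torch.manual_seed(0)
stand_in = nn.Linear(4, 1)
x = torch.randn(32, 4)
y = x.sum(dim=1, keepdim=True)
history = retrain(stand_in, x, y, epochs=20, lr=0.05)
```

In a real retraining run, the stand-in model and random tensors would be replaced by prune_retrain_model and a DataLoader over the original training set.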