restore_prune_retrain_model

Function Usage

API for filter-level sparsity or 2:4 structured sparsity. Only one of the two sparsity features can be enabled at a time.
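To make the difference between the two patterns concrete, here is a small pure-Python sketch that builds a mask for each one. In 2:4 structured sparsity, at most 2 of every 4 consecutive weights stay nonzero. In filter-level sparsity, whole filters (output channels) are removed. The helper names are hypothetical and are not AMCT functions. Keeping the largest-magnitude weights and the filters with the largest L1 norm is an assumption chosen for illustration, and AMCT's own selection criteria may differ.

```python
# Illustration only: hypothetical helpers, not part of amct_pytorch.
# Magnitude / L1-norm selection is an assumption; AMCT may select differently.

def two_four_mask(row):
    """Return a 0/1 mask keeping the 2 largest-magnitude values in every group of 4."""
    if len(row) % 4 != 0:
        raise ValueError("row length must be a multiple of 4")
    mask = []
    for start in range(0, len(row), 4):
        group = row[start:start + 4]
        # Indices of the two entries with the largest absolute value in this group.
        keep = sorted(range(4), key=lambda i: abs(group[i]), reverse=True)[:2]
        mask.extend(1 if i in keep else 0 for i in range(4))
    return mask


def filter_mask(filters, num_keep):
    """Return a per-filter 0/1 mask keeping the num_keep filters with the largest L1 norm.

    filters is a list of filters, each flattened to a list of floats.
    """
    norms = [sum(abs(w) for w in f) for f in filters]
    keep = sorted(range(len(filters)), key=lambda i: norms[i], reverse=True)[:num_keep]
    return [1 if i in keep else 0 for i in range(len(filters))]


def apply_mask(values, mask):
    """Zero out the values whose mask entry is 0."""
    return [v * m for v, m in zip(values, mask)]


if __name__ == "__main__":
    row = [0.1, -0.9, 0.3, 0.05, 0.7, 0.2, -0.4, 0.0]
    mask = two_four_mask(row)
    print(mask)                   # [0, 1, 1, 0, 1, 0, 1, 0]
    print(apply_mask(row, mask))  # [0.0, -0.9, 0.3, 0.0, 0.7, 0.0, -0.4, 0.0]
    # Three filters with L1 norms 2.0, 0.2 and 1.0; keep the two largest.
    print(filter_mask([[1.0, -1.0], [0.1, 0.1], [0.5, 0.5]], 2))  # [1, 0, 1]
```

The 2:4 mask keeps a fixed fraction of every row, while the filter mask removes entire output channels, which is why enabling one pattern excludes the other.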

This API sparsifies the graph of the input model based on the given sparsity record file record_file and returns a torch.nn.Module that can be used for retraining.

Prototype

prune_retrain_model = restore_prune_retrain_model(model, input_data, record_file, config_defination, pth_file, state_dict_name=None)

Parameters

model
Input/Return: Input
Description: Model to prune, with weights loaded.
Restriction: A torch.nn.Module.

input_data
Input/Return: Input
Description: Input data of the model. A single torch.Tensor is treated the same as a one-element tuple, (torch.Tensor,).
Restriction: A tuple.

record_file
Input/Return: Input
Description: Path (including the file name) of the sparsity record file generated by the create_prune_retrain_model API. Reusing this file ensures that the models generated by the two APIs are consistent.
Restriction: A string.

config_defination
Input/Return: Input
Description: Simplified configuration file, that is, prune.cfg generated from the retrain_config_pytorch.proto file. The retrain_config_pytorch.proto template is stored in /amct_pytorch/proto/retrain_config_pytorch.proto under the AMCT installation path. For details about the parameters in retrain_config_pytorch.proto and the generated simplified configuration file prune.cfg, see Simplified QAT Configuration File.
Restriction: A string.

pth_file
Input/Return: Input
Description: Path of the weight file saved during training.
Restriction: A string.

state_dict_name
Input/Return: Input
Description: Key under which the weights are stored in the weight file. Default: None.
Restriction: A string.

prune_retrain_model
Input/Return: Return
Description: Sparsified torch.nn.Module that can be used for retraining.
Restriction: A torch.nn.Module.
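The input_data and state_dict_name parameters follow two small conventions. The sketch below makes them concrete using hypothetical helper names (normalize_input_data, select_state_dict) that are not part of amct_pytorch. It assumes the weight file was saved as a dict, for example torch.save({'state_dict': model.state_dict(), 'epoch': epoch}, path), and that state_dict_name=None means the loaded file is itself the state dict. AMCT's actual handling may differ.

```python
# Hypothetical helpers, not part of amct_pytorch. They only illustrate the
# input_data and state_dict_name conventions described in the parameter list.

def normalize_input_data(input_data):
    """Treat a single input (e.g. one torch.Tensor) as a one-element tuple."""
    if isinstance(input_data, tuple):
        return input_data
    return (input_data,)


def select_state_dict(checkpoint, state_dict_name=None):
    """Pick the weights out of a checkpoint dict returned by torch.load.

    Assumption: with state_dict_name=None the checkpoint itself is the state dict;
    otherwise the weights are stored under checkpoint[state_dict_name].
    """
    if state_dict_name is None:
        return checkpoint
    if state_dict_name not in checkpoint:
        raise KeyError(f"{state_dict_name!r} not found; available keys: {sorted(checkpoint)}")
    return checkpoint[state_dict_name]


if __name__ == "__main__":
    # Plain dicts stand in for the objects torch.load would return.
    ckpt = {"state_dict": {"conv.weight": [1.0]}, "epoch": 10}
    print(select_state_dict(ckpt, "state_dict"))  # {'conv.weight': [1.0]}
    print(normalize_input_data("x"))              # ('x',)
```

Raising a KeyError that lists the available keys makes a wrong state_dict_name easy to diagnose before any retraining starts.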

Returns

prune_retrain_model: the sparsified torch.nn.Module, which can be used for retraining.

Outputs

Pruned model.

Examples

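import os

import torch
import amct_pytorch as amct

# Placeholders: replace build_model, input_shape, state_dict_path and TMP
# with the values for your own network and environment.
config_defination = './prune_cfg.cfg'
save_pth_path = '/your/path/to/save/tmp.pth'

# Build the network to prune and load its weights.
model = build_model()
model.load_state_dict(torch.load(state_dict_path))
input_data = (torch.randn(input_shape),)

# Sparsity record file generated earlier by create_prune_retrain_model.
record_file = os.path.join(TMP, 'scale_offset_record.txt')

# Call the API to restore the sparsified model for retraining.
prune_retrain_model = amct.restore_prune_retrain_model(
    model,
    input_data,
    record_file,
    config_defination,
    save_pth_path,
    state_dict_name='state_dict')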