create_prune_retrain_model

Function Usage

Applies filter-level sparsity or 2:4 structured sparsity to a model. Only one of the two sparsity features can be enabled at a time.

  • Filter-level sparsity: Sparsifies the graph based on the sparsity configuration file provided by the user. This function inserts mask operators into the graph to achieve fake sparsity for inference.
  • 2:4 structured sparsity: Sparsifies the graph based on the sparsity configuration file provided by the user. This function inserts 2:4 structured sparsity operators into the graph to achieve the fake-sparsity effect during inference.
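To make the two modes concrete, the following is a minimal, pure-Python sketch of the kind of masks involved. It is not AMCT's implementation: the L1-norm criterion for choosing filters, the keep_ratio parameter, grouping 2:4 blocks along a flat row, and the helper names filter_mask, apply_filter_mask, mask_2_4 and apply_mask are all hypothetical, chosen only for illustration. Multiplying weights by a 0/1 mask keeps the tensor shape unchanged, which is one way to read "fake sparsity": the zeros are simulated rather than physically removed.

```python
"""Illustrative sparsity masks (hypothetical helpers, not the AMCT implementation)."""
from typing import List

Matrix = List[List[float]]


def filter_mask(filters: Matrix, keep_ratio: float) -> List[int]:
    """Return a 0/1 flag per filter, keeping the filters with the largest L1 norm.

    The L1-norm criterion is an assumption; AMCT may select filters differently.
    """
    norms = [sum(abs(w) for w in f) for f in filters]
    n_keep = max(1, round(len(filters) * keep_ratio))
    # Indices of the n_keep filters with the largest norm
    keep = sorted(range(len(filters)), key=lambda i: norms[i], reverse=True)[:n_keep]
    return [1 if i in keep else 0 for i in range(len(filters))]


def apply_filter_mask(filters: Matrix, mask: List[int]) -> Matrix:
    # Zero every weight of a pruned filter; kept filters are unchanged
    return [[w * m for w in f] for f, m in zip(filters, mask)]


def mask_2_4(row: List[float]) -> List[int]:
    """Return a 0/1 mask keeping the 2 largest-magnitude values in every group of 4."""
    if len(row) % 4:
        raise ValueError("row length must be a multiple of 4")
    mask: List[int] = []
    for start in range(0, len(row), 4):
        group = row[start:start + 4]
        # Positions (0..3) of the two largest magnitudes in this group
        top2 = sorted(range(4), key=lambda i: abs(group[i]), reverse=True)[:2]
        mask.extend(1 if i in top2 else 0 for i in range(4))
    return mask


def apply_mask(values: List[float], mask: List[int]) -> List[float]:
    # Element-wise multiply: masked entries become zero, the length is unchanged
    return [v * m for v, m in zip(values, mask)]
```

For example, filter_mask([[0.1, -0.2], [1.0, 0.5], [-0.05, 0.0], [0.7, -0.9]], 0.5) keeps the second and fourth filters and returns [0, 1, 0, 1], while mask_2_4 always leaves exactly two non-zero positions in each block of four.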

Restrictions

None

Prototype

create_prune_retrain_model(graph, outputs, record_file, config_defination)

Parameters

  • graph (input): tf.Graph of the model to be pruned. Restriction: a tf.Graph.
  • outputs (input): names of the model outputs. Restriction: a list of strings, for example, ['output1', 'output2', ...].
  • record_file (input): path and file name of the file that records sparsity information. This file records the cascade relationships between filter-level sparsity nodes or between 2:4 structured sparsity nodes. Restriction: a string.
  • config_defination (input): path of the sparsity configuration file provided by the user, which specifies the sparsity configuration of each layer in the tf.Graph model. This simplified configuration file is created from the retrain_config_tf.proto file, located at /amct_tensorflow/proto/retrain_config_tf.proto in the AMCT installation path. For details about the parameters in retrain_config_tf.proto and the generated sparsity configuration file prune.cfg, see Simplified QAT Configuration File. Restriction: a string.
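The Restriction column can be checked before calling the API. The sketch below is hypothetical: check_prune_args and to_operation_names are not part of AMCT. It assumes, as the operation_name_1 placeholders in the example suggest, that outputs takes operation names rather than tensor names. In TensorFlow a tensor is named '<op_name>:<output_index>', so a name such as 'logits:0' is reduced to 'logits'.

```python
"""Hypothetical pre-call checks mirroring the parameter restrictions above."""
from typing import List


def to_operation_names(names: List[str]) -> List[str]:
    """Strip a TensorFlow tensor suffix such as ':0' so only operation names remain."""
    result = []
    for name in names:
        op_name, sep, index = name.rpartition(":")
        # Strip only when the suffix really is a numeric output index
        result.append(op_name if sep and index.isdigit() else name)
    return result


def check_prune_args(outputs, record_file, config_defination) -> None:
    """Raise TypeError if an argument does not match its documented restriction."""
    if not isinstance(outputs, list) or not all(isinstance(o, str) for o in outputs):
        raise TypeError("outputs must be a list of strings")
    for arg_name, value in (("record_file", record_file),
                            ("config_defination", config_defination)):
        if not isinstance(value, str):
            raise TypeError(f"{arg_name} must be a string")
```

For example, to_operation_names(['logits:0', 'scope/probs:1']) returns ['logits', 'scope/probs'], which could then be passed as outputs.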

Returns

None

Outputs

None

Examples

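import amct_tensorflow as amct
# graph: tf.Graph of the model; operation_name_1/2: names (strings) of the model output operations
amct.create_prune_retrain_model(graph, [operation_name_1, operation_name_2], './tmp/record.txt', './tmp/sample_prune.cfg')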