create_prune_retrain_model
Function Usage
API for filter-level sparsity or 2:4 structured sparsity. Only one of the two sparsity features can be enabled at a time.
- Filter-level sparsity: Sparsifies the graph based on the sparsity configuration file provided by the user. This function inserts mask operators into the graph to emulate sparsity (fake sparsity) during inference.
- 2:4 structured sparsity: Sparsifies the graph based on the sparsity configuration file provided by the user. This function inserts 2:4 structured sparsity operators into the graph to emulate sparsity (fake sparsity) during inference.
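The masking idea behind both modes can be illustrated with a minimal, framework-free sketch. It is not AMCT's implementation: the helper names (`two_four_mask`, `filter_mask`, `apply_mask`) are hypothetical, and selecting weights by magnitude or filters by L2 norm is an assumption made only for illustration. The sketch shows that weights are multiplied by a 0/1 mask rather than removed, which is why the result is referred to as fake sparsity.

```python
def two_four_mask(weights):
    """Return a 0/1 mask that keeps 2 values in every group of 4.

    Assumption for illustration: the two largest-magnitude values are kept.
    """
    if len(weights) % 4 != 0:
        raise ValueError("weight count must be a multiple of 4")
    mask = []
    for start in range(0, len(weights), 4):
        group = weights[start:start + 4]
        # Positions of the two largest magnitudes within this group of four.
        keep = sorted(range(4), key=lambda i: abs(group[i]), reverse=True)[:2]
        mask.extend(1 if i in keep else 0 for i in range(4))
    return mask


def filter_mask(filters, prune_ratio):
    """Return one 0/1 flag per filter; 0 marks a masked (pruned) filter.

    Assumption for illustration: filters with the smallest L2 norm are masked.
    """
    norms = [sum(w * w for w in f) ** 0.5 for f in filters]
    n_prune = int(len(filters) * prune_ratio)
    pruned = set(sorted(range(len(filters)), key=lambda i: norms[i])[:n_prune])
    return [0 if i in pruned else 1 for i in range(len(filters))]


def apply_mask(weights, mask):
    """Multiply weights by the mask; shapes are unchanged, values are zeroed."""
    return [w * m for w, m in zip(weights, mask)]


weights = [0.9, -0.1, 0.05, -0.7, 0.2, 0.3, -0.4, 0.1]
mask_2_4 = two_four_mask(weights)
masked_weights = apply_mask(weights, mask_2_4)

filters = [[1.0, 1.0], [0.1, 0.1], [2.0, 0.0], [0.0, 0.5]]
mask_filters = filter_mask(filters, prune_ratio=0.5)
```

In the 2:4 case every group of four weights keeps exactly two nonzero entries, while in the filter-level case whole filters are switched off at once.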
Restrictions
None
Prototype
create_prune_retrain_model(graph, outputs, record_file, config_defination)
Parameters
| Parameter | Input/Return | Description | Restriction |
|---|---|---|---|
| graph | Input | A tf.Graph of the model to be pruned. | A tf.Graph. |
| outputs | Input | Model outputs. | A list of strings, for example, [output1,output2,...]. |
| record_file | Input | Path, including the file name, of the file that records sparsity information. This file records the cascade relationships between filter-level sparsity nodes or 2:4 structured sparsity nodes. | A string. |
| config_defination | Input | Path of the sparsity configuration file provided by the user, which specifies the sparsity configuration of each layer in the tf.Graph model. This simplified configuration file (for example, prune.cfg) is written according to the retrain_config_tf.proto file, located at /amct_tensorflow/proto/retrain_config_tf.proto in the AMCT installation path. For details about the parameters in the retrain_config_tf.proto file and the resulting sparsity configuration file prune.cfg, see Simplified QAT Configuration File. | A string. |
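The restrictions in the table can be checked before the API is called. The following is a minimal sketch, not part of AMCT: `check_prune_args` is a hypothetical helper, and the requirement that the configuration file already exists on disk is an assumption based on its description as user-provided input. The `graph` argument is not checked so that the sketch runs without TensorFlow installed.

```python
import os


def check_prune_args(outputs, record_file, config_defination):
    """Validate argument types as described in the parameter table.

    Hypothetical helper for illustration; AMCT may perform its own checks.
    """
    # outputs: a non-empty list of strings.
    if not isinstance(outputs, list) or not outputs:
        raise TypeError("outputs must be a non-empty list of strings")
    if not all(isinstance(name, str) for name in outputs):
        raise TypeError("every entry in outputs must be a string")

    # record_file: a string path that includes the file name.
    if not isinstance(record_file, str) or not os.path.basename(record_file):
        raise TypeError("record_file must be a string path including a file name")

    # config_defination: a string path to the user-provided configuration file.
    if not isinstance(config_defination, str):
        raise TypeError("config_defination must be a string")
    if not os.path.isfile(config_defination):
        # Assumption: the sparsity configuration file must exist beforehand.
        raise FileNotFoundError(config_defination)
```

A call such as `check_prune_args(['output1'], './tmp/record.txt', './tmp/sample_prune.cfg')` returns nothing when all checks pass and raises an exception otherwise.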
Returns
None
Outputs
None
Examples
amct.create_prune_retrain_model(graph, [operation_name_1, operation_name_2], './tmp/record.txt', './tmp/sample_prune.cfg')