create_prune_retrain_model
Applicability
| Product | Supported |
|---|---|
Note: For products marked with x, calling this API with the 2:4 structured sparsity feature does not report an error, but the performance benefits of 2:4 structured sparsity cannot be obtained.
Description
Applies filter-level sparsity or 2:4 structured sparsity to a graph. Only one of the two sparsity features can be enabled at a time.
- Filter-level sparsity: Sparsifies the graph according to the sparsity configuration file specified by the user. This function inserts filter-level sparsity mask operators into the graph to implement fake sparsity for inference.
- 2:4 structured sparsity: Sparsifies the graph according to the sparsity configuration file specified by the user. This function inserts 2:4 structured sparsity operators into the graph to implement fake sparsity for inference.
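To make filter-level fake sparsity concrete, the following is a minimal pure-Python sketch, not AMCT's implementation. It assumes each filter (output channel) is a flat list of weights and that filters are ranked by L1 norm using a hypothetical `prune_ratio`; the helper names `filter_prune_mask` and `apply_filter_mask` are also hypothetical. The actual selection rule comes from the sparsity configuration file and is not described in this section. In the sketch, the mask zeroes whole filters while keeping the tensor shape unchanged, so the sparsity is simulated rather than physically removing filters.

```python
from typing import List


# Hypothetical helpers for illustration only; not part of the AMCT API.
def filter_prune_mask(filters: List[List[float]], prune_ratio: float) -> List[int]:
    """Return one 0/1 mask bit per filter, zeroing the filters with the smallest L1 norms.

    The L1-norm ranking and prune_ratio are assumptions made for this sketch.
    """
    if not 0.0 <= prune_ratio < 1.0:
        raise ValueError("prune_ratio must be in [0, 1)")
    norms = [sum(abs(w) for w in f) for f in filters]
    num_pruned = int(len(filters) * prune_ratio)
    # Filter indices ordered from smallest to largest L1 norm.
    order = sorted(range(len(filters)), key=lambda i: norms[i])
    pruned = set(order[:num_pruned])
    return [0 if i in pruned else 1 for i in range(len(filters))]


def apply_filter_mask(filters: List[List[float]], mask: List[int]) -> List[List[float]]:
    """Multiply every weight of a filter by that filter's mask bit; the shape is unchanged."""
    return [[w * m for w in f] for f, m in zip(filters, mask)]


if __name__ == "__main__":
    filters = [
        [0.1, -0.2, 0.05],    # L1 norm 0.35
        [1.0, -0.5, 0.3],     # L1 norm 1.8
        [0.01, 0.02, -0.01],  # L1 norm 0.04
        [0.7, 0.2, -0.4],     # L1 norm 1.3
    ]
    mask = filter_prune_mask(filters, prune_ratio=0.5)
    print(mask)  # [0, 1, 0, 1]
    print(apply_filter_mask(filters, mask))
```

Because the mask is applied by multiplication, the zeroed filters in this sketch still occupy memory and compute; only the values change.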
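The 2:4 pattern allows at most two nonzero values in every group of four consecutive weights. The sketch below illustrates the general pattern and is not AMCT's operator: it builds a mask by keeping the two largest-magnitude weights in each group. The magnitude criterion, the grouping over a flat list, and the helper names `mask_2_4` and `apply_mask` are assumptions; which tensor axis is grouped in practice is not stated in this section.

```python
from typing import List


# Hypothetical helpers for illustration only; not part of the AMCT API.
def mask_2_4(weights: List[float]) -> List[int]:
    """Return a 0/1 mask that keeps the 2 largest-magnitude weights in each group of 4."""
    if len(weights) % 4 != 0:
        raise ValueError("number of weights must be a multiple of 4")
    mask: List[int] = []
    for start in range(0, len(weights), 4):
        group = weights[start:start + 4]
        # Positions inside the group, ranked by descending magnitude; keep the top two.
        keep = sorted(range(4), key=lambda i: abs(group[i]), reverse=True)[:2]
        mask.extend(1 if i in keep else 0 for i in range(4))
    return mask


def apply_mask(weights: List[float], mask: List[int]) -> List[float]:
    """Fake sparsity: zero the masked weights but keep the list length unchanged."""
    return [w * m for w, m in zip(weights, mask)]


if __name__ == "__main__":
    weights = [0.9, -0.1, 0.3, -0.8, 0.05, 0.6, -0.7, 0.2]
    mask = mask_2_4(weights)
    print(mask)  # [1, 0, 0, 1, 0, 1, 1, 0]
    print(apply_mask(weights, mask))
```

The fixed two-of-four constraint is what makes the pattern regular enough for hardware with 2:4 sparsity support to exploit; on products without that support, as the note above says, no performance benefit is obtained.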
Prototype
create_prune_retrain_model(graph, outputs, record_file, config_defination)
Parameters
| Parameter | Input/Output | Description |
|---|---|---|
| graph | Input | tf.Graph of the model to be sparsified. Data type: tf.Graph. |
| outputs | Input | Output nodes of the user model. Data type: list of strings, for example, [output1, output2, ...]. |
| record_file | Input | Path (including the file name) of the record file that stores sparsity information, namely the cascade relationships between filter-level sparsity nodes or between 2:4 structured sparsity nodes. Data type: string. |
| config_defination | Input | Path of the sparsity configuration file specified by the user, which defines the sparsity configuration of each layer in the tf.Graph. The simplified configuration file prune.cfg is generated based on the retrain_config_tf.proto file, which is stored in /amct_tensorflow/proto/ under the AMCT installation directory. For details about the parameters in the *.proto file and the generated sparsity configuration file prune.cfg, see Simplified QAT Configuration File. Data type: string. |
Returns
None
Example
import amct_tensorflow as amct
amct.create_prune_retrain_model(graph, [operation_name_1, operation_name_2], './tmp/record.txt', './tmp/sample_prune.cfg')