decompose_network
Function Usage
Decomposes the input PyTorch model object according to the decomposition information file saved by the auto_decomposition call, and returns the decomposed model object together with the layer names before and after decomposition.
Constraints
- The input model must be an object of the torch.nn.Module type.
- This API modifies only the structure of convolution layers built with torch.nn.Conv2d.
- The structure of the input model must be consistent with the structure used when auto_decomposition generated the decomposition information file.
Prototype
model, changes = decompose_network(model, decompose_info_path)
Parameters
| Parameter | Input/Return | Description | Restriction |
|---|---|---|---|
| model | Input | PyTorch model object to be decomposed. When calling this API, you are advised to place the model on the CPU rather than the GPU to avoid running out of GPU memory during decomposition. | A torch.nn.Module. |
| decompose_info_path | Input | Path of the decomposition information file generated by the auto_decomposition call. | A string. |
| model | Return | Model object whose structure has been changed to the structure after tensor decomposition. | A torch.nn.Module. |
| changes | Return | Dictionary mapping each layer name before tensor decomposition to the list of layer names after decomposition, for example, {'conv1': ['conv1.0', 'conv1.1'], 'conv2': ['conv2.0', 'conv2.1'], ...}. | A dict. |
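The `changes` dictionary is handy when other per-layer settings are keyed by the original layer names. The following is a minimal sketch, not part of amct_pytorch: `remap_layer_settings` and the per-layer learning-rate multiplier dict `lr_scale` are hypothetical names used only for illustration. It copies each original layer's setting to every layer that replaced it and keeps names that were not decomposed unchanged.

```python
from typing import Dict, List, TypeVar

T = TypeVar("T")


def remap_layer_settings(settings: Dict[str, T], changes: Dict[str, List[str]]) -> Dict[str, T]:
    """Propagate settings keyed by original layer names to the decomposed layer names.

    `settings` is a hypothetical per-layer configuration (e.g. learning-rate multipliers).
    `changes` has the shape returned by decompose_network: {old_name: [new_name, ...]}.
    """
    remapped: Dict[str, T] = {}
    for name, value in settings.items():
        # A decomposed layer is replaced by several new layers; each inherits the old setting.
        # Layers that were not decomposed keep their original name.
        for new_name in changes.get(name, [name]):
            remapped[new_name] = value
    return remapped


changes = {"conv1": ["conv1.0", "conv1.1"], "conv2": ["conv2.0", "conv2.1"]}
lr_scale = {"conv1": 0.1, "fc": 1.0}
print(remap_layer_settings(lr_scale, changes))
# {'conv1.0': 0.1, 'conv1.1': 0.1, 'fc': 1.0}
```

Copying the same value to every new layer is only one possible policy; whether it is appropriate depends on what the setting means for your training setup.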
Returns
A tuple (model, changes): the model object after tensor decomposition, and a dictionary of the corresponding layer names before and after tensor decomposition.
Outputs
None
Examples
```python
from amct_pytorch.tensor_decompose import decompose_network

net = Net()  # (*) Build a model object.
# Load the decomposition information file to modify the model structure.
net, changes = decompose_network(
    model=net,
    # Path of the decomposition information file stored by auto_decomposition.
    decompose_info_path="decomposed_path/decompose_info.json",
)
```
- If training is involved, call this API before the model parameters are passed to the optimizer. If DDP (DistributedDataParallel) is used, call this API before the model is wrapped with DDP.
- This API modifies the input model object in place; that is, the model object passed in by the user is changed by decomposition. The exception is an input model that is itself a torch.nn.Conv2d object: this API does not modify it, and the returned decomposed model is a newly constructed torch.nn.Module object.
- This API modifies only the model structure and does not update the convolution weights after decomposition; the weights keep the default values initialized by torch.nn.Conv2d. If fine-tuning is required, save the weights of the decomposed model after calling auto_decomposition, load them after calling this API, and then fine-tune the model.
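Because decomposition replaces a layer such as `conv1` with new layers such as `conv1.0` and `conv1.1`, the parameter names in the model's state dict change as well, so a checkpoint saved from the original model cannot be loaded into the decomposed model as-is. The sketch below is illustrative only: the helper `split_checkpoint_keys` is hypothetical, and it assumes PyTorch's usual `<module name>.<parameter name>` key convention (for example, `conv1.weight`) and that the decomposed layers are registered under the names listed in `changes`. It splits keys saved before decomposition into those that still refer to an existing layer and those that became stale.

```python
from typing import Dict, List, Tuple


def split_checkpoint_keys(
    keys: List[str], changes: Dict[str, List[str]]
) -> Tuple[List[str], List[str]]:
    """Split state-dict keys saved before decomposition into (still valid, stale).

    Assumes the "<module name>.<parameter name>" key convention, e.g. "conv1.weight".
    Keys owned by a layer listed in `changes` are stale, because that layer no longer
    exists under its original name after decomposition.
    """
    valid: List[str] = []
    stale: List[str] = []
    for key in keys:
        owner = key.rsplit(".", 1)[0]  # "conv1.weight" -> "conv1"
        (stale if owner in changes else valid).append(key)
    return valid, stale


changes = {"conv1": ["conv1.0", "conv1.1"]}
keys = ["conv1.weight", "conv1.bias", "bn1.weight", "fc.weight"]
valid, stale = split_checkpoint_keys(keys, changes)
print(valid)  # ['bn1.weight', 'fc.weight']
print(stale)  # ['conv1.weight', 'conv1.bias']
```

The stale entries illustrate why the weights to fine-tune from should be saved from the decomposed model after auto_decomposition rather than taken from a checkpoint of the original model.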