auto_decomposition
Function Usage
(Optional) Performs tensor decomposition on the input PyTorch model object. The function returns the decomposed model object together with the layer names before and after decomposition, and can save a decomposition information file.
Constraints
- The input model must be an object of the torch.nn.Module type.
- This API decomposes only convolutional layers constructed with torch.nn.Conv2d.
- This API automatically decomposes the convolutional layers that meet the decomposition conditions. For details about the conditions, see Restrictions.
Prototype
model, changes = auto_decomposition(model, decompose_info_path=None)
Parameters
| Option | Input/Return | Description | Restriction |
|---|---|---|---|
| model | Input | PyTorch model object that contains pre-trained weights to be decomposed. When calling this API, you are advised to place the model on the CPU instead of the GPU to prevent insufficient GPU memory during decomposition. | A torch.nn.Module. |
| decompose_info_path | Input (optional) | Path for storing the decomposition information file. The file is stored in JSON format. Therefore, the .json file name extension is recommended. If the value is None, the decomposition information file is not saved (default). | A string. |
| model | Return | Model object after tensor decomposition. | A torch.nn.Module. |
| changes | Return | Dictionary consisting of the layer names before and after tensor decomposition, for example, {'conv1': ['conv1.0', 'conv1.1'], 'conv2': ['conv2.0', 'conv2.1'], ...}. | A dict. |
Returns
Two values: the model object after tensor decomposition, and a dictionary that maps each original layer name to the layer names that replace it after decomposition.
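The changes dictionary can be inverted to find which original layer a decomposed layer came from, for instance when remapping per-layer settings. The following is a minimal sketch, assuming the dictionary has the shape shown in the table above. `invert_changes` is a hypothetical helper, not part of this API, and the sample dictionary is illustrative only.

```python
def invert_changes(changes):
    """Map each post-decomposition layer name to its original layer name."""
    new_to_old = {}
    for old_name, new_names in changes.items():
        for new_name in new_names:
            new_to_old[new_name] = old_name
    return new_to_old


# Illustrative value in the documented format; a real one would be
# the second value returned by auto_decomposition().
changes = {"conv1": ["conv1.0", "conv1.1"], "conv2": ["conv2.0", "conv2.1"]}
new_to_old = invert_changes(changes)
print(new_to_old["conv2.1"])
```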
Outputs
decompose_info_path: decomposition information file in JSON format, saved only when decompose_info_path is not None.
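Before passing a path as decompose_info_path, it can help to validate it. The sketch below uses a hypothetical helper, `prepare_info_path`, that is not part of this API. It assumes nothing about whether the API creates missing directories, so it creates them defensively. It also rejects non-.json names, which is stricter than the API's recommendation.

```python
import tempfile
from pathlib import Path


def prepare_info_path(path):
    """Create the parent directory of `path` and return the path as a string.

    Raises ValueError if the file name does not end in .json. This check is
    a choice made for this sketch; the API only recommends the extension.
    """
    p = Path(path)
    if p.suffix.lower() != ".json":
        raise ValueError(f"expected a .json file name, got {p.name!r}")
    p.parent.mkdir(parents=True, exist_ok=True)
    return str(p)  # the parameter is documented as a string


# Demonstration inside a temporary directory, so nothing is left behind.
with tempfile.TemporaryDirectory() as tmp:
    info_path = prepare_info_path(Path(tmp) / "decomposed_path" / "decompose_info.json")
    parent_exists = Path(info_path).parent.is_dir()
```

The returned string could then be supplied as the decompose_info_path argument.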
Example
    import torch
    from amct_pytorch.tensor_decompose import auto_decomposition

    net = Net()  # (*) Build a model object.
    net.load_state_dict(torch.load("src_path/weights.pth"))  # (*) Load model weights.
    net, changes = auto_decomposition(  # Tensor decomposition
        model=net,
        decompose_info_path="decomposed_path/decompose_info.json"
    )
- If training is involved, call this API before the model parameters are passed to the optimizer. If DDP is used, call this API before the model parameters are passed to DDP.
- This API modifies the input model object in place. That is, the model object passed in by the user is changed by decomposition. Exception: if the input model is itself a torch.nn.Conv2d object, this API does not modify it. If decomposition occurs in that case, a newly constructed torch.nn.Module object is returned.