decompose_network
功能说明
用户输入PyTorch模型对象和通过auto_decomposition保存的分解信息文件,根据分解信息文件将模型对象改变为张量分解后的结构,得到分解后的模型对象和分解前后层的对应名称。
约束说明
- 用户输入的模型需为torch.nn.Module类型的对象。
- 本接口函数仅支持对通过torch.nn.Conv2d()构建的卷积的结构修改。
- 用户输入的模型结构需与调用auto_decomposition获取分解信息文件时的模型结构一致,分解信息文件要与该模型结构配套使用。
函数原型
model, changes = decompose_network(model, decompose_info_path)
参数说明
参数名 |
输入/返回值 |
含义 |
使用限制 |
---|---|---|---|
model |
输入 |
待分解的PyTorch模型对象。在调用该接口时建议将模型放置于CPU而不是GPU上,以防分解时显存不足。 |
数据类型:torch.nn.Module |
decompose_info_path |
输入 |
分解信息文件路径,该文件通过auto_decomposition获得。 |
数据类型:string |
model |
返回值 |
改变为张量分解后结构的模型对象。 |
数据类型:torch.nn.Module |
changes |
返回值 |
张量分解前后的对应层名构成的字典,形如{'conv1': ['conv1.0', 'conv1.1'], 'conv2': ['conv2.0', 'conv2.1'], ...}。 |
数据类型:dict |
返回值说明
改变为张量分解后结构的模型对象、张量分解前后的对应层名构成的字典。
函数输出
无。
调用示例
1 2 3 4 5 6 |
from amct_pytorch.tensor_decompose import decompose_network net = Net() # 构建用户模型对象 net, changes = decompose_network( # 加载分解信息文件,将模型结构修改为张量分解后的结构 model=net, decompose_info_path="decomposed_path/decompose_info.json" # 由auto_decomposition保存的分解信息文件路径 ) |

- 当涉及模型训练时,本接口的调用需在将模型参数传递给优化器之前;如使用了torch.nn.parallel.DistributedDataParallel (DDP),则本接口的调用也需在将模型传递给DDP之前。
- 本接口将原地修改传入的模型对象,即分解后会改变用户传入的模型对象本身(例外:传入的模型是一个torch.nn.Conv2d对象,该情况下本接口不会对其进行修改,返回的分解后模型是新构建的torch.nn.Module对象)。
- 本接口仅对模型结构进行修改,不会更新分解后的卷积权重,权重的值为torch.nn.Conv2d()构建的默认值。如需finetune,请在调用auto_decomposition后将分解后的模型权重保存下来,在调用本接口之后加载该权重,再进行finetune。
父主题: 张量分解接口