decompose_network

产品支持情况

产品

是否支持

Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件

Atlas 200I/500 A2 推理产品

Atlas 推理系列产品

Atlas 训练系列产品

功能说明

用户输入PyTorch模型对象和通过auto_decomposition保存的分解信息文件,根据分解信息文件将模型对象改变为张量分解后的结构,得到分解后的模型对象和分解前后层的对应名称。

函数原型

1
model, changes = decompose_network(model, decompose_info_path)

参数说明

参数名

输入/输出

说明

model

输入

含义:待分解的PyTorch模型对象。在调用该接口时建议将模型放置于CPU而不是GPU上,以防分解时显存不足。

数据类型:torch.nn.Module

decompose_info_path

输入

含义:分解信息文件路径,该文件通过auto_decomposition获得。

数据类型:string

返回值说明

约束说明

调用示例

1
2
3
4
5
6
from amct_pytorch.tensor_decompose import decompose_network
net = Net()                                                      # 构建用户模型对象
net, changes = decompose_network(                                # 加载分解信息文件,将模型结构修改为张量分解后的结构
    model=net,
    decompose_info_path="decomposed_path/decompose_info.json"    # 由auto_decomposition保存的分解信息文件路径
)
  1. 当涉及模型训练时,本接口的调用需在将模型参数传递给优化器之前;如使用了torch.nn.parallel.DistributedDataParallel (DDP),则本接口的调用也需在将模型传递给DDP之前。
  2. 本接口将原地修改传入的模型对象,即分解后会改变用户传入的模型对象本身(例外:传入的模型是一个torch.nn.Conv2d对象,该情况下本接口不会对其进行修改,返回的分解后模型是新构建的torch.nn.Module对象)。
  3. 本接口仅对模型结构进行修改,不会更新分解后的卷积权重,权重的值为torch.nn.Conv2d()构建的默认值。如需finetune,请在调用auto_decomposition后将分解后的模型权重保存下来,在调用本接口之后加载该权重,再进行finetune。