剪枝过程需要通过前向传播获取相关信息,默认模型输入是一个tensor。实际模型可能是多个变量,或一个dict,如下代码所示。
# 模型forward有多个输入 def forward(self, x, y, z): pass # 模型forward输入是一个dict def forward(self, {'x':x, 'y':y, 'z':z}): pass
这种情况下,需要通过torch.save(data, save_path)存下[x, y, z]列表或{'x':x, 'y':y, 'z':z}字典(batchsize为2),存为.pkl后缀的文件。该pickle文件地址记为pkl_data_path。