样例代码说明
本文档提供的样例是基于PyTorch官网的Imagenet数据集训练模型脚本代码main.py,以PyTorch1.8.1为例。
因为当前适配的昇腾PyTorch版本没有torch.backends.mps这个模块,所以需要将原代码中所有mps模块相关代码注释掉后再进行迁移。具体如下:
- 原代码第147行:
if not torch.cuda.is_available(): #and not torch.backends.mps.is_available(): print('using CPU, this will be slow')
- 原代码第171行至173行:
#elif torch.backends.mps.is_available(): #device = torch.device("mps") #model = model.to(device)
- 原代码第187至188行:
#elif torch.backends.mps.is_available(): #device = torch.device("mps")
- 原代码第356行至358行:
#if torch.backends.mps.is_available(): #images = images.to('mps') #target = target.to('mps')
- 原代码第443至444行:
#elif torch.backends.mps.is_available(): #device = torch.device("mps")
父主题: 样例参考