样例代码说明
本文档提供的样例是基于PyTorch官网Imagenet数据集的训练模型代码脚本main.py,以PyTorch1.11.0为例。
因为当前昇腾适配的PyTorch版本没有torch.backends.mps这个模块,所以需要将原代码中所有mps模块相关代码注释后再进行迁移。具体如下:
- 原代码第147行:
if not torch.cuda.is_available(): # and not torch.backends.mps.is_available(): print('using CPU, this will be slow')
- 原代码第171行至173行:
# elif torch.backends.mps.is_available(): # device = torch.device("mps") # model = model.to(device)
- 原代码第187至188行:
# elif torch.backends.mps.is_available(): # device = torch.device("mps")
- 原代码第356行至358行:
# if torch.backends.mps.is_available(): # images = images.to('mps') # target = target.to('mps')
- 原代码第443至444行:
# elif torch.backends.mps.is_available(): # device = torch.device("mps")
父主题: 样例参考