PyTorch部分接口返回tensor的连续性和stride值与CPU不一致
2025/05/28
15
问题信息
问题来源 | 产品大类 | 产品子类 | 关键字 |
---|---|---|---|
官方 | 模型训练 | PyTorch | -- |
问题现象描述
部分原生PyTorch接口在输入非连续张量的情况下,CPU输出为非连续张量,NPU输出为连续张量。
示例如下所示:
import torch import torch_npu input_cpu = torch.randn(3,2).T input_npu = torch.randn(3,2).npu().T cpu_out = torch.where(input_cpu>0, 0.0, 1.0) npu_out = torch.where(input_npu>0, 0.0, 1.0) print(cpu_out.is_contiguous()) print(npu_out.is_contiguous()) print(torch.equal(cpu_out, npu_out.cpu()))
通过打印可以发现,cpu_out为非连续张量,npu_out为连续张量,二者值相同。
原因分析
NPU设备当前对非连续计算不亲和,除少数算子外,大部分算子都只能对连续张量进行计算。出于设备亲和性考虑,NPU上非inplace/out类接口均被设计为输出连续张量。连续张量仅在“stride”上和非连续张量存在差异,“shape”和“value”均相同,一般不会对使用产生影响。
解决措施
在必须获取非连续张量的场景(如需要根据“stride”进行判断),可以参考如下方案进行处理:
- 将接口输入中的张量转为CPU张量,获取“stride”信息后再重新将结果转回NPU张量。
- 如果使用的接口有out类变体(如add->add.out),可以创建一个符合预期“shape”或“stride”的NPU张量,利用out类接口会保持连续性的特性进行计算。