实现函数化转换
“函数化转换”可以简单理解为将In-place算子替换为非In-place算子的过程,例如将torch.ops.aten.add_替换为torch.ops.aten.add。
PyTorch图模式基于函数化后的FX图工作,因此In-place类算子与PyTorch图模式配合工作时,需要实现函数化转换,实现将图上的In-place算子替换为非In-place算子。
换言之需要为In-place算子torch.ops.npu.my_inplace注册对应的非In-place算子用于替换,同时完成非In-place算子的Eager模式实现、Meta推导函数,以及In-place算子到非In-place算子的转换函数。
In-place类算子需要实现函数化,社区在PyTorch 2.5+版本提供了自动函数化转换能力,TorchAir将在未来版本支持该特性,目前仍需要您手动实现函数化。
函数化具体操作步骤如下:
注册非In-place算子
非In-place算子名要求:一般定义为“In-place算子名”+“_functional”后缀,同时由于非In-place算子将结果写入输出而非直接修改输入,因此In-place算子被修改的输入需要添加对应的输出。
在third_party/op-plugin/op_plugin/python/meta/_meta_registrations.py中,追加如下内容注册非In-place算子:
import torch
from torch.library import Library, impl
m_fragment = Library("npu", "FRAGMENT")
m_fragment.define("my_inplace_functional(Tensor x, Tensor y) -> (Tensor, Tensor)")
my_inplace_functional算子原型含义:包含两个输入x和y,输出两个新的Tensor。从逻辑上,第一个输出Tensor的值,与被In-place修改后的x一致,第二个则与被In-place修改后的y一致。
非In-place算子Eager模式NPU实现
my_inplace_functional的NPU实现正常执行时不会调用,但在图模式精度调试时非常重要,因此建议实现。
在third_party/op-plugin/op_plugin/python/meta/_meta_registrations.py中,追加如下内容支持Eager模式调用:
注意:NPU通过PrivateUse1设备扩展接入PyTorch,因此实现时的Dispatch key为PrivateUse1
import torch
from torch.library import Library, impl
@impl(m_fragment, "my_inplace_functional", "PrivateUse1")
def custom_add_npu(x, y):
x_clone = x.clone()
y_clone = y.clone()
torch.ops.npu.my_inplace(x_clone, y_clone)
return x_clone, y_clone
非In-place算子实现Meta推导
在third_party/op-plugin/op_plugin/python/meta/_meta_registrations.py中,追加如下内容实现Meta推导函数:
函数化是为my_inplace实现Functionalized的DispatchKey,my_inplace的计算逻辑不能变化,仍然需要原地修改输入x和y,而不是作为输出返回。
import torch
from torch.library import Library, impl
@impl(m, "my_inplace_functional")
def my_inplace_functional_meta(x, y):
return torch.empty_like(x), torch.empty_like(y)
In-place算子实现函数化转换
在third_party/op-plugin/op_plugin/python/meta/_meta_registrations.py中,使用my_inplace_functional(非In-place算子)替换原始的my_inplace(In-place算子),实现my_inplace算子的函数化。
import torch
from torch.library import Library, impl
@torch.library.impl(m_fragment, "my_inplace", "Functionalize")
def my_inplace_functional_npu(x, y):
x_out, y_out = torch.ops.npu.my_inplace_functional(x, y)
x.copy_(x_out)
y.copy_(y_out)