昇腾社区首页
中文
注册
开发者
下载

实现函数化转换

“函数化转换”可以简单理解为将In-place算子替换为非In-place算子的过程,例如将torch.ops.aten.add_替换为torch.ops.aten.add。

PyTorch图模式基于函数化后的FX图工作,因此In-place类算子与PyTorch图模式配合工作时,需要实现函数化转换,实现将图上的In-place算子替换为非In-place算子。

换言之需要为In-place算子torch.ops.npu.my_inplace注册对应的非In-place算子用于替换,同时完成非In-place算子的Eager模式实现、Meta推导函数,以及In-place算子到非In-place算子的转换函数。

In-place类算子需要实现函数化,社区在PyTorch 2.5+版本提供了自动函数化转换能力,TorchAir将在未来版本支持该特性,目前仍需要您手动实现函数化。

函数化具体操作步骤如下:

  1. 注册非In-place算子
  2. 非In-place算子Eager模式NPU实现
  3. 非In-place算子实现Meta推导
  4. In-place算子实现函数化转换

注册非In-place算子

非In-place算子名要求:一般定义为“In-place算子名”+“_functional”后缀,同时由于非In-place算子将结果写入输出而非直接修改输入,因此In-place算子被修改的输入需要添加对应的输出。

在third_party/op-plugin/op_plugin/python/meta/_meta_registrations.py中,追加如下内容注册非In-place算子:

import torch
from torch.library import Library, impl

m_fragment = Library("npu", "FRAGMENT")
m_fragment.define("my_inplace_functional(Tensor x, Tensor y) -> (Tensor, Tensor)")

my_inplace_functional算子原型含义:包含两个输入x和y,输出两个新的Tensor。从逻辑上,第一个输出Tensor的值,与被In-place修改后的x一致,第二个则与被In-place修改后的y一致。

非In-place算子Eager模式NPU实现

my_inplace_functional的NPU实现正常执行时不会调用,但在图模式精度调试时非常重要,因此建议实现。

在third_party/op-plugin/op_plugin/python/meta/_meta_registrations.py中,追加如下内容支持Eager模式调用:

注意:NPU通过PrivateUse1设备扩展接入PyTorch,因此实现时的Dispatch key为PrivateUse1

import torch
from torch.library import Library, impl

@impl(m_fragment, "my_inplace_functional", "PrivateUse1")
def custom_add_npu(x, y):
    x_clone = x.clone()
    y_clone = y.clone()
    torch.ops.npu.my_inplace(x_clone, y_clone)
    return x_clone, y_clone

非In-place算子实现Meta推导

在third_party/op-plugin/op_plugin/python/meta/_meta_registrations.py中,追加如下内容实现Meta推导函数:

函数化是为my_inplace实现Functionalized的DispatchKey,my_inplace的计算逻辑不能变化,仍然需要原地修改输入x和y,而不是作为输出返回。

import torch
from torch.library import Library, impl

@impl(m, "my_inplace_functional")
def my_inplace_functional_meta(x, y):
    return torch.empty_like(x), torch.empty_like(y)

In-place算子实现函数化转换

在third_party/op-plugin/op_plugin/python/meta/_meta_registrations.py中,使用my_inplace_functional(非In-place算子)替换原始的my_inplace(In-place算子),实现my_inplace算子的函数化。

import torch
from torch.library import Library, impl

@torch.library.impl(m_fragment, "my_inplace", "Functionalize")
def my_inplace_functional_npu(x, y):
    x_out, y_out = torch.ops.npu.my_inplace_functional(x, y)
    x.copy_(x_out)
    y.copy_(y_out)