昇腾社区首页
中文
注册

算子注册PyTorch

自定义算子入图前,需确保成功注册PyTorch框架,即已生成对应的ATen IR。具体实现步骤如下:

  1. 适配插件开发,实现PyTorch框架调用自定义算子。

    通过Ascend Extension for PyTorch中的OpPlugin算子插件实现算子注册分发(yaml文件中配置算子的定义等)和PyTorch适配插件实现即可

    详细的操作步骤请参考《PyTorch 框架特性指南》中的“自定义算子适配开发”章节

  2. 为了能够正确入FX图,需要为自定义算子注册Meta函数,通过PyTorch的Meta后端帮助算子完成入图时所需要的shape和data type推导,即算子在调用后输出的shape、dtype可根据算子的入参来推导。
    • Meta函数是PyTorch原生Fake Tensor的运行后端,自定义算子注册到PyTorch,必须借助Meta后端完成算子的InferShape,才能Dynamo Trace成FX图。关于Meta函数的更多介绍和定义,可参见PyTorch官网“The dynamic shapes manual”。
    • Meta函数必须在torch.compile执行前完成注册。
     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    import torch
    import torch_npu
    from torch.library import impl
    from torch_npu.op_plugin.meta import _meta_registrations as m
    
    # 为npu_add_custom算子注册Meta
    @impl(m.m, "npu_add_custom", "Meta")   
    # Meta函数要求入参与出参shape/dtype保持一致            
    def npu_add_custom_meta(x, y):                   
        return torch.empty_like(x)