昇腾社区首页
EN
注册

算子注册PyTorch

自定义算子入图前,需确保成功注册PyTorch框架,即已生成对应的Aten IR。具体实现步骤如下:

  1. 适配插件开发,实现PyTorch框架调用自定义算子。

    通过Ascend Extension for PyTorch中的OP-Plugin算子插件实现算子注册分发(yaml文件中配置算子的定义等)和PyTorch适配插件实现即可

    详细的操作步骤请参考《套件与三方库支持清单》中的“单算子适配OpPlugin插件开发”章节

  2. 为了能够正确入图,需要为自定义算子接口注册Meta函数,以帮助该API完成入图时的shape和data type推导,即算子在调用后输出的shape、dtype可根据算子的入参来推导。
     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    import torch
    import torch_npu
    from torch.library import impl
    from torch_npu.meta._meta_registrations import m
    
    # 为npu_add_custom算子注册Meta
    @impl(m, "npu_add_custom", "Meta")   
    # Meta函数要求入参与出参shape/dtype保持一致            
    def npu_add_custom_meta(x, y):                   
        return torch.empty_like(x)