算子注册PyTorch
自定义算子入图前,需确保成功注册PyTorch框架,即已生成对应的Aten IR。具体实现步骤如下:
- 适配插件开发,实现PyTorch框架调用自定义算子。
- 为了能够正确入图,需要为自定义算子接口注册Meta函数,以帮助该API完成入图时的shape和data type推导,即算子在调用后输出的shape、dtype可根据算子的入参来推导。
1 2 3 4 5 6 7 8 9 10
import torch import torch_npu from torch.library import impl from torch_npu.meta._meta_registrations import m # 为npu_add_custom算子注册Meta @impl(m, "npu_add_custom", "Meta") # Meta函数要求入参与出参shape/dtype保持一致 def npu_add_custom_meta(x, y): return torch.empty_like(x)
父主题: 自定义算子插件化入图