实现Meta推导函数
PyTorch原生要求所有能与torch.compile配合工作的算子需要实现Meta推导函数,又称为“符号化推导”。Meta函数表示了PyTorch算子输出与输入shape、dtype以及内存的关系,它是PyTorch入图的前提条件,借助符号化和符号guard可静态化控制流和形状信息,从而确定图结构。关于Meta函数的详细介绍请参考PyTorch官网符号化手册。
进入third_party/op-plugin/op_plugin/python/meta/_meta_registrations.py实现Meta推导函数:
import torch
from torch.library import Library, impl
# meta register implementation
m = Library("npu", "IMPL", "Meta")
@impl(m, "my_op")
def my_op_meta(x, y, z, attr1, attr2):
return torch.empty_like(x)
- my_op_meta:Meta函数名,通常以PyTorch算子名+"_meta"后缀命名。
- m:表示NPU算子的Meta实现库,通常定义在文件开头“m=Library("npu", "IMPL", "Meta")”。
父主题: 非In-place算子开发和入图样例
