昇腾社区首页
中文
注册
开发者
下载

实现Meta推导函数

PyTorch原生要求所有能与torch.compile配合工作的算子需要实现Meta推导函数,又称为“符号化推导”。Meta函数表示了PyTorch算子输出与输入shape、dtype以及内存的关系,它是PyTorch入图的前提条件,借助符号化和符号guard可静态化控制流和形状信息,从而确定图结构。关于Meta函数的详细介绍请参考PyTorch官网符号化手册

  • Meta推导函数必须在torch.compile执行前完成注册。
  • torch.library.Library接口介绍请参考PyTorch官网

进入third_party/op-plugin/op_plugin/python/meta/_meta_registrations.py实现Meta推导函数:

import torch
from torch.library import Library, impl
# meta register implementation
m = Library("npu", "IMPL", "Meta")

@impl(m, "my_op")
def my_op_meta(x, y, z, attr1, attr2):
    return torch.empty_like(x)
  • my_op_meta:Meta函数名,通常以PyTorch算子名+"_meta"后缀命名。
  • m:表示NPU算子的Meta实现库,通常定义在文件开头“m=Library("npu", "IMPL", "Meta")”。