昇腾社区首页
中文
注册

Operation

针对用户调用单OP的场景,Operation提供创建OP、执行OP、OP信息调试等接口。可以直接通过该对象完成OP相关的操作。

创建Operation

接口名称

CLASS torch_atb.Operation(OpParam)

接口描述

创建单OP

Input要求/参数(可选)

OpParam

Output要求/参数(可选)

torch_atb.Operation

目前支持的入参类型有:LayerNormParam、ElewiseParam、LinearParam、SoftmaxParam、RopeParam、SelfAttentionParam、PagedAttentionParam、SplitParam、ReshapeAndCacheParam、GatherParam、ActivationParam、RmsNormParam。

name

功能

torch_atb.Operation的属性,用于获取torch_atb.Operation的名称。

原型

torch_atb.Operation.name -> str

参数

NA

返回值

Operation的名称。

注意

调用接口失败,会抛出异常终止运行。

使用示例

linear = torch_atb.Operation(linear_param)
print(linear.name)

input_num

功能

torch_atb.Operation的属性,用于获取输入tensor的个数。

原型

torch_atb.Operation.input_num ->int

参数

NA

返回值

输入tensor的个数。

注意

调用接口失败,会抛出异常终止运行。

使用示例

linear = torch_atb.Operation(linear_param)
print(linear.input_num)

output_num

功能

torch_atb.Operation的属性,用于获取输出tensor的个数。

原型

torch_atb.Operation.output_num ->int

参数

NA

返回值

输出tensor的个数。

注意

调用接口失败,会抛出异常终止运行。

使用示例

linear = torch_atb.Operation(linear_param)
print(linear.output_num)

forward

功能

torch_atb.Operation的方法,用于执行Operation。

原型

torch_atb.Operation.forward(List:torch::Tensor) -> List:torch::Tensor

参数

输入tensor,类型为List[torch.Tensor]。

返回值

输出tensor,类型为List[torch.Tensor]。

注意

调用接口失败,会抛出异常终止运行。

使用示例

input = torch.randn(m, k, dtype=torch.float16).npu()
weight = torch.randn(k, n, dtype=torch.float16).npu()
linear_param = torch_atb.LinearParam()
linear_param.has_bias = False
linear_param.transpose_b = False
linear = torch_atb.Operation(linear_param)
linear_outputs = linear.forward([input, weight])
print("linear_outputs : ", linear_outputs )