Operation
针对用户调用单OP的场景,Operation提供创建OP、执行OP、OP信息调试等接口。可以直接通过该对象完成OP相关的操作。
创建Operation
接口名称 |
CLASS torch_atb.Operation(OpParam) |
---|---|
接口描述 |
创建单OP |
Input要求/参数(可选) |
OpParam |
Output要求/参数(可选) |
torch_atb.Operation |
目前支持的入参类型有:LayerNormParam、ElewiseParam、LinearParam、SoftmaxParam、RopeParam、SelfAttentionParam、PagedAttentionParam、SplitParam、ReshapeAndCacheParam、GatherParam、ActivationParam、RmsNormParam。
name
功能 |
torch_atb.Operation的属性,用于获取torch_atb.Operation的名称。 |
---|---|
原型 |
torch_atb.Operation.name -> str |
参数 |
NA |
返回值 |
Operation的名称。 |
注意 |
调用接口失败,会抛出异常终止运行。 |
使用示例:
linear = torch_atb.Operation(linear_param) print(linear.name)
input_num
功能 |
torch_atb.Operation的属性,用于获取输入tensor的个数。 |
---|---|
原型 |
torch_atb.Operation.input_num ->int |
参数 |
NA |
返回值 |
输入tensor的个数。 |
注意 |
调用接口失败,会抛出异常终止运行。 |
使用示例:
linear = torch_atb.Operation(linear_param) print(linear.input_num)
output_num
功能 |
torch_atb.Operation的属性,用于获取输出tensor的个数。 |
---|---|
原型 |
torch_atb.Operation.output_num ->int |
参数 |
NA |
返回值 |
输出tensor的个数。 |
注意 |
调用接口失败,会抛出异常终止运行。 |
使用示例:
linear = torch_atb.Operation(linear_param) print(linear.output_num)
forward
功能 |
torch_atb.Operation的方法,用于执行Operation。 |
---|---|
原型 |
torch_atb.Operation.forward(List:torch::Tensor) -> List:torch::Tensor |
参数 |
输入tensor,类型为List[torch.Tensor]。 |
返回值 |
输出tensor,类型为List[torch.Tensor]。 |
注意 |
调用接口失败,会抛出异常终止运行。 |
使用示例:
input = torch.randn(m, k, dtype=torch.float16).npu() weight = torch.randn(k, n, dtype=torch.float16).npu() linear_param = torch_atb.LinearParam() linear_param.has_bias = False linear_param.transpose_b = False linear = torch_atb.Operation(linear_param) linear_outputs = linear.forward([input, weight]) print("linear_outputs : ", linear_outputs )