Operation
针对用户调用单OP的场景,Operation提供创建OP、执行OP、OP信息调试等接口。可以直接通过该对象完成OP相关的操作。
创建Operation
接口名称  | 
CLASS torch_atb.Operation(OpParam)  | 
|---|---|
接口描述  | 
创建单OP  | 
Input要求/参数(可选)  | 
OpParam  | 
Output要求/参数(可选)  | 
torch_atb.Operation  | 
目前支持的入参类型有:LayerNormParam、ElewiseParam、LinearParam、SoftmaxParam、RopeParam、SelfAttentionParam、PagedAttentionParam、SplitParam、ReshapeAndCacheParam、GatherParam、ActivationParam、RmsNormParam。
name
功能  | 
torch_atb.Operation的属性,用于获取torch_atb.Operation的名称。  | 
|---|---|
原型  | 
torch_atb.Operation.name -> str  | 
参数  | 
NA  | 
返回值  | 
Operation的名称。  | 
注意  | 
调用接口失败,会抛出异常终止运行。  | 
使用示例:
linear = torch_atb.Operation(linear_param) print(linear.name)
input_num
功能  | 
torch_atb.Operation的属性,用于获取输入tensor的个数。  | 
|---|---|
原型  | 
torch_atb.Operation.input_num ->int  | 
参数  | 
NA  | 
返回值  | 
输入tensor的个数。  | 
注意  | 
调用接口失败,会抛出异常终止运行。  | 
使用示例:
linear = torch_atb.Operation(linear_param) print(linear.input_num)
output_num
功能  | 
torch_atb.Operation的属性,用于获取输出tensor的个数。  | 
|---|---|
原型  | 
torch_atb.Operation.output_num ->int  | 
参数  | 
NA  | 
返回值  | 
输出tensor的个数。  | 
注意  | 
调用接口失败,会抛出异常终止运行。  | 
使用示例:
linear = torch_atb.Operation(linear_param) print(linear.output_num)
forward
功能  | 
torch_atb.Operation的方法,用于执行Operation。  | 
|---|---|
原型  | 
torch_atb.Operation.forward(List:torch::Tensor) -> List:torch::Tensor  | 
参数  | 
输入tensor,类型为List[torch.Tensor]。  | 
返回值  | 
输出tensor,类型为List[torch.Tensor]。  | 
注意  | 
调用接口失败,会抛出异常终止运行。  | 
使用示例:
input = torch.randn(m, k, dtype=torch.float16).npu()
weight = torch.randn(k, n, dtype=torch.float16).npu()
linear_param = torch_atb.LinearParam()
linear_param.has_bias = False
linear_param.transpose_b = False
linear = torch_atb.Operation(linear_param)
linear_outputs = linear.forward([input, weight])
print("linear_outputs : ", linear_outputs )