SplitParam
| Attribute | Type | Default | Description |
|---|---|---|---|
| split_dim | int | 0 | - |
| split_num | int | 2 | - |
| split_sizes | List[int] | list() | - |
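This page does not describe how the three attributes interact. The helper below is a sketch that rests on an unconfirmed assumption. It assumes that an empty `split_sizes` splits `split_dim` evenly into `split_num` parts, and that a non-empty `split_sizes` gives each output's extent along `split_dim`. Under that assumption it computes the output shapes you would expect. `split_output_shapes` is a hypothetical name and is not part of `torch_atb`.

```python
from typing import List, Optional, Sequence, Tuple


def split_output_shapes(
    shape: Sequence[int],
    split_dim: int = 0,
    split_num: int = 2,
    split_sizes: Optional[List[int]] = None,
) -> List[Tuple[int, ...]]:
    """Return the expected shape of each output of a split (hypothetical helper).

    ASSUMPTION (not confirmed by the SplitParam docs): an empty split_sizes
    means an even split into split_num parts; otherwise each entry of
    split_sizes is one output's size along split_dim.
    """
    if not 0 <= split_dim < len(shape):
        raise ValueError(f"split_dim {split_dim} out of range for rank {len(shape)}")
    extent = shape[split_dim]
    if split_sizes:
        # Explicit sizes must cover the whole dimension exactly.
        if sum(split_sizes) != extent:
            raise ValueError(f"split_sizes {split_sizes} do not sum to {extent}")
        sizes = list(split_sizes)
    else:
        # Even split: the dimension must divide cleanly into split_num parts.
        if split_num <= 0 or extent % split_num != 0:
            raise ValueError(f"size {extent} cannot be split evenly into {split_num} parts")
        sizes = [extent // split_num] * split_num
    shapes = []
    for size in sizes:
        out = list(shape)
        out[split_dim] = size  # only the split dimension changes
        shapes.append(tuple(out))
    return shapes


if __name__ == "__main__":
    # Mirrors the call example: a 6x6 input split along dim 0 into 2 parts.
    print(split_output_shapes((6, 6), split_dim=0, split_num=2))
```

For the 6x6 input used in the call example, this assumed behavior yields two outputs of shape (3, 6).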
Constraints
The length of split_sizes must be less than DEFAULT_SVECTOR_SIZE. Otherwise, a length-overflow exception is thrown.
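A caller can reject an over-long `split_sizes` before building the operation. The sketch below is hypothetical. `check_split_sizes_length` is not a `torch_atb` function. This page does not state the value of DEFAULT_SVECTOR_SIZE, so the limit is passed in as `max_len` rather than hard-coded. The limit of 8 in the demo is an arbitrary placeholder.

```python
from typing import Sequence


def check_split_sizes_length(split_sizes: Sequence[int], max_len: int) -> None:
    """Raise ValueError unless len(split_sizes) < max_len (hypothetical pre-check).

    max_len stands in for DEFAULT_SVECTOR_SIZE, whose actual value is not
    documented here and must be supplied by the caller.
    """
    if len(split_sizes) >= max_len:
        raise ValueError(
            f"len(split_sizes)={len(split_sizes)} must be less than {max_len}"
        )


if __name__ == "__main__":
    HYPOTHETICAL_LIMIT = 8  # placeholder, NOT the real DEFAULT_SVECTOR_SIZE
    check_split_sizes_length([2, 4], HYPOTHETICAL_LIMIT)  # passes silently
    try:
        check_split_sizes_length([1] * 8, HYPOTHETICAL_LIMIT)
    except ValueError as err:
        print("rejected:", err)
```

The comparison is strict because the constraint says the length must be *less than* DEFAULT_SVECTOR_SIZE, so a list whose length equals the limit is rejected.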
Example
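```python
import torch
import torch_npu  # registers the NPU device so that .npu() is available
import torch_atb


def split():
    # 6x6 float16 input tensor placed on the NPU
    input_npu = torch.randn(6, 6, dtype=torch.float16).npu()
    print("input: ", input_npu)

    # Split along dimension 0 into 2 parts
    split_param = torch_atb.SplitParam(split_dim=0, split_num=2)
    # Named split_op so it does not shadow the enclosing split() function
    split_op = torch_atb.Operation(split_param)

    def split_run():
        # forward() takes the input tensors as a list
        split_outputs = split_op.forward([input_npu])
        return split_outputs

    outputs = split_run()
    print("outputs: ", outputs)


if __name__ == "__main__":
    split()
```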