SplitParam
Attribute | Type | Default | Description |
---|---|---|---|
split_dim | int | 0 | - |
split_num | int | 2 | - |
split_sizes | List[int] | list() | - |
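The table does not spell out how the three attributes interact. The sketch below is a hypothetical CPU-side helper (`split_output_shapes` is not part of torch_atb) that predicts output shapes under two assumptions: a non-empty `split_sizes` gives the extent of each output along `split_dim`, and otherwise `split_dim` is divided into `split_num` equal parts. It only illustrates the intended semantics; the ATB operator documentation remains authoritative.

```python
from typing import List, Sequence


def split_output_shapes(shape: Sequence[int], split_dim: int = 0,
                        split_num: int = 2,
                        split_sizes: Sequence[int] = ()) -> List[List[int]]:
    """Hypothetical helper: predict the output shapes of a split.

    Assumption: a non-empty split_sizes lists the extent of each output
    along split_dim; otherwise split_dim is cut into split_num equal parts.
    """
    if not 0 <= split_dim < len(shape):
        raise ValueError(f"split_dim {split_dim} out of range for shape {list(shape)}")
    extent = shape[split_dim]
    if split_sizes:
        # Explicit sizes must exactly cover the split dimension.
        if sum(split_sizes) != extent:
            raise ValueError("split_sizes must sum to the size of split_dim")
        sizes = list(split_sizes)
    else:
        # Equal split: the dimension must divide evenly by split_num.
        if split_num <= 0 or extent % split_num != 0:
            raise ValueError("size of split_dim must be divisible by split_num")
        sizes = [extent // split_num] * split_num
    result = []
    for size in sizes:
        out = list(shape)
        out[split_dim] = size  # only the split dimension changes
        result.append(out)
    return result


print(split_output_shapes([6, 6]))                                   # defaults
print(split_output_shapes([6, 6], split_dim=1, split_sizes=[2, 4]))  # explicit sizes
```

With the defaults, a 6 x 6 input yields two 3 x 6 outputs, which matches the configuration used in the example on this page.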
Constraints
The length of split_sizes must be less than DEFAULT_SVECTOR_SIZE; otherwise a length-overflow exception is thrown.
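The length constraint can be checked on the caller side before the parameter is constructed. The following is a minimal sketch, assuming the caller supplies the value of DEFAULT_SVECTOR_SIZE (it is not stated on this page) as `max_len`; `check_split_sizes` and `SplitSizesOverflowError` are hypothetical names, not torch_atb APIs.

```python
from typing import Sequence


class SplitSizesOverflowError(ValueError):
    """Hypothetical stand-in for the length-overflow exception."""


def check_split_sizes(split_sizes: Sequence[int], max_len: int) -> None:
    """Raise unless len(split_sizes) < max_len.

    max_len is a placeholder for DEFAULT_SVECTOR_SIZE, whose actual value
    must be taken from the ATB headers.
    """
    if len(split_sizes) >= max_len:
        raise SplitSizesOverflowError(
            f"len(split_sizes)={len(split_sizes)} must be less than {max_len}")


check_split_sizes([2, 4], max_len=8)  # passes silently
```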
Example
```python
import torch
import torch_npu  # registers the NPU device so that .npu() is available
import torch_atb


def split():
    # 6x6 half-precision input tensor on the NPU
    input_npu = torch.randn(6, 6, dtype=torch.float16).npu()
    print("input: ", input_npu)

    # Split along dimension 0 into 2 parts
    split_param = torch_atb.SplitParam(split_dim=0, split_num=2)
    split_op = torch_atb.Operation(split_param)

    def split_run():
        # Inputs are passed to forward as a list of tensors
        split_outputs = split_op.forward([input_npu])
        return split_outputs

    outputs = split_run()
    print("outputs: ", outputs)


if __name__ == "__main__":
    split()
```