昇腾社区首页
中文
注册

SplitParam

属性

类型

默认值

描述

split_dim

int

0

-

split_num

int

2

-

split_sizes

List[int]

list()

-

约束说明

split_sizes的长度大小需小于DEFAULT_SVECTOR_SIZE,否则会抛出长度溢出异常。

调用示例

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
import torch
import torch_atb  

def split():
    input_npu = torch.randn(6, 6, dtype=torch.float16).npu()
    print("input: ", input_npu)
    split_param = torch_atb.SplitParam(split_dim = 0, split_num = 2)
    split = torch_atb.Operation(split_param)

    def split_run():
        split_outputs = split.forward([input_npu])
        return split_outputs

    outputs = split_run()
    print("outputs: ", outputs)

if __name__ == "__main__":
    split()