SplitOperation

功能

将输入张量沿指定维度切分成多个张量。

算子上下文

图1 SplitOperation

算子功能实现

定义

1
2
3
4
5
6
struct SplitParam {
    int32_t splitDim = 0;
    int32_t splitNum = 2;
    SVector<int32_t> splitSizes = {};
    uint8_t rsv[8] = {0};
};

参数列表

成员名称

类型

默认值

描述

splitDim

int32_t

0

指定切分的维度索引。

splitDim须位于输入张量x的维度范围内,即如果x的维度为xDim,则splitDim的取值范围为[-xDim, xDim - 1]。 当splitDim为负数时,其含义是从最高维度开始访问,如splitDim = -1,x维度数为dimNum,则拆分维度为dimNum - 1。

splitNum

int32_t

2

等分次数,当前支持2或3。

输入张量x的维度须能够被splitNum整除,且当splitNum = 3时输入x要求是float16或者bf16数据类型。

splitSizes

SVector<int32_t>

-

指定每个输出tensor在切分维度上的大小,不传入此参数时使用等长切分,传入此参数时使用splitV不等长切分。

rsv[8]

uint8_t

{0}

预留参数。

splitNum=2时输入输出

参数

维度

数据类型

格式

描述

x

[dim_0, ..., dim_splitDim, ..., dim_n]

float16/int64/bf16

ND

输入,最高支持8维。

output1

  • [dim_0, …, dim_splitDim_1, ..., dim_n]
  • 当splitSizes为空时,dim_splitDim_1=dim_splitDim/2;
  • 当splitSizes不为空时,dim_splitDim_2=splitSizes[0]。

float16/int64/bf16

ND

输出,切分后的tensor。数据类型与x一致。

output2

  • [dim_0, …, dim_splitDim_2, ..., dim_n]
  • 当splitSizes为空时,dim_splitDim_2=dim_splitDim/2;
  • 当splitSizes不为空时,dim_splitDim_2=splitSizes[1]。

float16/int64/bf16

ND

输出,切分后的tensor。数据类型与x一致。

splitNum=3时输入输出

参数

维度

数据类型

格式

描述

x

[dim_0, ..., dim_splitDim,...,dim_n]

float16/bf16

ND

输入,最高支持8维。

output1

  • [dim_0, …, dim_splitDim_1, ..., dim_n]
  • 当splitSizes为空时,dim_splitDim_1=dim_splitDim/2;
  • 当splitSizes不为空时,dim_splitDim_1=splitSizes[0]。

float16/bf16

ND

输出,切分后的tensor。数据类型与x一致。

output2

  • [dim_0, …, dim_splitDim_2, ..., dim_n]
  • 当splitSizes为空时,dim_splitDim_2=dim_splitDim/3;
  • 当splitSizes不为空时,dim_splitDim_2=splitSizes[1]。

float16/bf16

ND

输出,切分后的tensor。数据类型与x一致。

output3

  • [dim_0, …, dim_splitDim_3, ..., dim_n]
  • 当splitSizes为空时,dim_splitDim_3=dim_splitDim/3;
  • 当splitSizes不为空时,dim_splitDim_3=splitSizes[2]。

float16/bf16

ND

输出,切分后的tensor。数据类型与x一致。

规格约束

接口调用示例

输入:

splitDim = 0
splitNum = 3
x = [3, 3, 3, 3, 3, 3, 3, 3, 3]

输出:

z = [3, 3, 3]
z1 = [3, 3, 3]
z2 = [3, 3, 3]

算子调用示例(C++)

前置条件和编译命令请参见算子调用示例

场景:基础场景。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#include <iostream>
#include <vector>
#include <numeric>
#include <random>
#include "acl/acl.h"
#include "atb/operation.h"
#include "atb/types.h"
#include "atb/atb_infer.h"

#include "demo_util.h"

int main(int argc, char **argv)
{
    // 设置卡号、创建context、设置stream
    CHECK_STATUS(aclInit(nullptr));
    int32_t deviceId = 0;
    CHECK_STATUS(aclrtSetDevice(deviceId));
    atb::Context *context = nullptr;
    CHECK_STATUS(atb::CreateContext(&context));
    void *stream = nullptr;
    CHECK_STATUS(aclrtCreateStream(&stream));
    context->SetExecuteStream(stream);

    // 配置Op参数
    atb::infer::SplitParam opParam;
    opParam.splitDim = 1;         // 设定切分轴为1
    opParam.splitNum = 2;         // 设置切分后得到的块数
    opParam.splitSizes = {2, 3};  // 设置不均匀切分时每块大小

    // 准备VariantPack
    atb::VariantPack variantPack;
    std::vector<int64_t> inputXShape = {1, 5, 2};
    std::vector<int64_t> output1Shape = {1, 2, 2};
    std::vector<int64_t> output2Shape = {1, 3, 2};

    std::vector<float> inTensorXData = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
    std::vector<float> output1RefData = {0, 1, 2, 3};
    std::vector<float> output2RefData = {4, 5, 6, 7, 8, 9};

    atb::Tensor inTensorX =
        CreateTensorFromVector(context, stream, inTensorXData, ACL_FLOAT16, aclFormat::ACL_FORMAT_ND, inputXShape);
    atb::Tensor outTensor1 = CreateTensor(ACL_FLOAT16, aclFormat::ACL_FORMAT_ND, output1Shape);
    atb::Tensor outTensor2 = CreateTensor(ACL_FLOAT16, aclFormat::ACL_FORMAT_ND, output2Shape);

    variantPack.inTensors = {inTensorX};
    variantPack.outTensors = {outTensor1, outTensor2};

    // 申请SplitOp
    atb::Operation *splitOp = {nullptr};
    CHECK_STATUS(atb::CreateOperation(opParam, &splitOp));
    uint64_t workspaceSize = 0;
    // ATB Operation 第一阶段接口调用:对输入输出进行检查,并根据需要计算workspace大小
    CHECK_STATUS(splitOp->Setup(variantPack, workspaceSize, context));
    uint8_t *workspacePtr = nullptr;
    if (workspaceSize > 0) {
        CHECK_STATUS(aclrtMalloc((void **)(&workspacePtr), workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST));
    }
    // ATB Operation 第二阶段接口调用:执行算子
    CHECK_STATUS(splitOp->Execute(variantPack, workspacePtr, workspaceSize, context));
    CHECK_STATUS(aclrtSynchronizeStream(stream));  // 流同步,等待device侧任务计算完成

    // 资源释放
    for (atb::Tensor &inTensor : variantPack.inTensors) {
        CHECK_STATUS(aclrtFree(inTensor.deviceData));
    }
    for (atb::Tensor &outTensor : variantPack.outTensors) {
        CHECK_STATUS(aclrtFree(outTensor.deviceData));
    }
    if (workspaceSize > 0) {
        CHECK_STATUS(aclrtFree(workspacePtr));
    }
    CHECK_STATUS(atb::DestroyOperation(splitOp));
    CHECK_STATUS(aclrtDestroyStream(stream));
    CHECK_STATUS(atb::DestroyContext(context));
    CHECK_STATUS(aclFinalize());
    std::cout << "Split demo success!" << std::endl;
    return 0;
}