ConcatOperation

功能

将两个输入张量在指定维度拼接成一个输出张量。

算子上下文

图1 ConcatOperation

算子功能实现

Concat 算子的功能是将两个输入张量在指定维度拼接成一个输出张量。

图2 Concat算子实现原理

定义

1
2
3
4
struct ConcatParam {
    int concatDim = 0;
    uint8_t rsv[12] = {0};
};

参数列表

成员名称

类型

默认值

描述

concatDim

int

0

指定拼接的维度索引。

假设拼接的维度索引为k,x维度数为dimNum:

  • 当concatDim为负数时,其含义是从最高维度开始访问,k=dimNum+concatDim。例如concatDim = -1,则拼接维度k为dimNum - 1。
  • 否则k=concatDim。

rsv[12]

uint8_t

{0}

预留参数。

输入

参数

维度

数据类型

格式

描述

x

[x_dim_0,x_dim_1,... ,x_dim_n]

float16/bf16

ND

需要被拼接的tensor。

y

[y_dim_0,y_dim_1,... ,y_dim_n]

float16/bf16

ND

需要被拼接的tensor。除了y_dim_k之外,其余维度与x相同。

输出

参数

维度

数据类型

格式

描述

output

[dim_0, dim_1, ..., dim_n]

float16/bf16

ND

拼接后的tensor。

规格约束

接口调用示例

算子调用示例(C++)

前置条件和编译命令请参见算子调用示例

场景:基础场景。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
#include <iostream>
#include <vector>
#include <numeric>
#include <random>
#include "acl/acl.h"
#include "atb/operation.h"
#include "atb/types.h"
#include "atb/atb_infer.h"
#include "demo_util.h"
int main(int argc, char **argv)
{
    // 设置卡号、创建context、设置stream
    CHECK_STATUS(aclInit(nullptr));
    int32_t deviceId = 0;
    CHECK_STATUS(aclrtSetDevice(deviceId));
    atb::Context *context = nullptr;
    CHECK_STATUS(atb::CreateContext(&context));
    void *stream = nullptr;
    CHECK_STATUS(aclrtCreateStream(&stream));
    context->SetExecuteStream(stream);
    // 配置Op参数
    atb::infer::ConcatParam opParam;
    opParam.concatDim = 1;  // 设定拼接轴为1
    // 准备VariantPack
    atb::VariantPack variantPack;
    std::vector<int64_t> inputXShape = {1, 3, 2};
    std::vector<int64_t> inputYShape = {1, 2, 2};
    std::vector<int64_t> outputShape = {1, 5, 2};
    std::vector<float> inTensorXData = {0, 1, 2, 3, 4, 5};
    std::vector<float> inTensorYData = {6, 7, 8, 9};
    std::vector<float> outputRefData = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
    atb::Tensor inTensorX =
        CreateTensorFromVector(context, stream, inTensorXData, ACL_FLOAT16, aclFormat::ACL_FORMAT_ND, inputXShape);
    atb::Tensor inTensorY =
        CreateTensorFromVector(context, stream, inTensorYData, ACL_FLOAT16, aclFormat::ACL_FORMAT_ND, inputYShape);
    atb::Tensor outTensor = CreateTensor(ACL_FLOAT16, aclFormat::ACL_FORMAT_ND, outputShape);
    variantPack.inTensors = {inTensorX, inTensorY};
    variantPack.outTensors = {outTensor};
    // 申请ConcatOp
    atb::Operation *concatOp = {nullptr};
    CHECK_STATUS(atb::CreateOperation(opParam, &concatOp));
    uint64_t workspaceSize = 0;
    // ATB Operation 第一阶段接口调用:对输入输出进行检查,并根据需要计算workspace大小
    CHECK_STATUS(concatOp->Setup(variantPack, workspaceSize, context));
    uint8_t *workspacePtr = nullptr;
    if (workspaceSize > 0) {
        CHECK_STATUS(aclrtMalloc((void **)(&workspacePtr), workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST));
    }
    // ATB Operation 第二阶段接口调用:执行算子
    CHECK_STATUS(concatOp->Execute(variantPack, workspacePtr, workspaceSize, context));
    CHECK_STATUS(aclrtSynchronizeStream(stream));  // 流同步,等待device侧任务计算完成
    // 资源释放
    for (atb::Tensor &inTensor : variantPack.inTensors) {
        CHECK_STATUS(aclrtFree(inTensor.deviceData));
    }
    for (atb::Tensor &outTensor : variantPack.outTensors) {
        CHECK_STATUS(aclrtFree(outTensor.deviceData));
    }
    if (workspaceSize > 0) {
        CHECK_STATUS(aclrtFree(workspacePtr));
    }
    CHECK_STATUS(atb::DestroyOperation(concatOp));
    CHECK_STATUS(aclrtDestroyStream(stream));
    CHECK_STATUS(atb::DestroyContext(context));
    CHECK_STATUS(aclFinalize());
    std::cout << "Concat demo success!" << std::endl;
    return 0;
}