ConcatOperation

功能

将两个输入张量在指定维度拼接成一个输出张量。

算子上下文

图1 ConcatOperation

算子功能实现

Concat 算子的功能是将两个输入张量在指定维度拼接成一个输出张量。

图2 Concat算子实现原理

定义

struct ConcatParam {
    int concatDim = 0;
};

参数列表

成员名称

类型

默认值

描述

concatDim

int

0

指定拼接的维度索引。

假设拼接的维度索引为k,x维度数为dimNum:

  • 当concatDim为负数时,其含义是从最高维度开始访问,k=dimNum+concatDim。例如concatDim = -1,则拼接维度k为dimNum - 1。
  • 否则k=concatDim。

输入

参数

维度

数据类型

格式

描述

x

[x_dim_0,x_dim_1,... ,x_dim_n]

float16/bf16

ND

需要被拼接的tensor。

y

[y_dim_0,y_dim_1,... ,y_dim_n]

float16/bf16

ND

需要被拼接的tensor。除了y_dim_k之外,其余维度与x相同。

输出

参数

维度

数据类型

格式

描述

output

[dim_0,dim_1,... ,dim_n]

float16/bf16

ND

拼接后的tensor。

约束

接口调用示例