昇腾社区首页
中文
注册

ConcatOperation

功能

将两个输入张量在指定维度拼接成一个输出张量。

算子上下文

图1 ConcatOperation

算子功能实现

Concat 算子的功能是将两个输入张量在指定维度拼接成一个输出张量。

图2 Concat算子实现原理

定义

struct ConcatParam {
    int concatDim = 0;
};

参数列表

成员名称

类型

默认值

描述

concatDim

int

0

指定拼接的维度索引。

假设拼接的维度索引为k,x维度数为dimNum:

  • 当concatDim为负数时,其含义是从最高维度开始访问,k=dimNum+concatDim。例如concatDim = -1,则拼接维度k为dimNum - 1。
  • 否则k=concatDim。

输入

参数

维度

数据类型

格式

描述

x

[x_dim_0,x_dim_1,... ,x_dim_n]

float16/bf16

ND

需要被拼接的tensor。

y

[y_dim_0,y_dim_1,... ,y_dim_n]

float16/bf16

ND

需要被拼接的tensor。除了y_dim_k之外,其余维度与x相同。

输出

参数

维度

数据类型

格式

描述

output

[dim_0,dim_1,... ,dim_n]

float16/bf16

ND

拼接后的tensor。

约束

  • Param参数约束
    • dimNum为被拼接tensor的维数。
    • -dimNum ≤ concatDim ≤ dimNum - 1。
  • 输入约束
    • 输入x和y的维数相等。
    • 输入x和y的维度大小,除了cancatDim维之外,其他维度要求相同。

接口调用示例

  • 输入
    x shape = [3, 2, 3]:
    [[[1.0, 1.0, 1.0],
      [1.0, 1.0, 1.0]],
     [[1.0, 1.0, 1.0],
      [1.0, 1.0, 1.0]],
     [[1.0, 1.0, 1.0],
      [1.0, 1.0, 1.0]]]
    y shape = [3, 1, 3]:
    [[[2.0, 2.0, 2.0]],
     [[2.0, 2.0, 2.0]],
     [[2.0, 2.0, 2.0]]]
    param.concatDim = 1(或-2)
  • 输出
    output shape = [3, 3, 3]
    [[[1.0, 1.0, 1.0],
      [1.0, 1.0, 1.0],
      [2.0, 2.0, 2.0]],
     [[1.0, 1.0, 1.0],
      [1.0, 1.0, 1.0],
      [2.0, 2.0, 2.0]],
     [[1.0, 1.0, 1.0],
      [1.0, 1.0, 1.0],
      [2.0, 2.0, 2.0]]]