ConcatOperation
功能
将两个输入张量在指定维度拼接成一个输出张量。
算子上下文
图1 ConcatOperation


算子功能实现
Concat 算子的功能是将两个输入张量在指定维度拼接成一个输出张量。
图2 Concat算子实现原理


定义
struct ConcatParam { int concatDim = 0; };
参数列表
成员名称 |
类型 |
默认值 |
描述 |
---|---|---|---|
concatDim |
int |
0 |
指定拼接的维度索引。 假设拼接的维度索引为k,x维度数为dimNum:
|
输入
参数 |
维度 |
数据类型 |
格式 |
描述 |
---|---|---|---|---|
x |
[x_dim_0,x_dim_1,... ,x_dim_n] |
float16/bf16 |
ND |
需要被拼接的tensor。 |
y |
[y_dim_0,y_dim_1,... ,y_dim_n] |
float16/bf16 |
ND |
需要被拼接的tensor。除了y_dim_k之外,其余维度与x相同。 |
输出
参数 |
维度 |
数据类型 |
格式 |
描述 |
---|---|---|---|---|
output |
[dim_0,dim_1,... ,dim_n] |
float16/bf16 |
ND |
拼接后的tensor。 |
约束
- Param参数约束
- dimNum为被拼接tensor的维数。
- -dimNum ≤ concatDim ≤ dimNum - 1。
- 输入约束
- 输入x和y的维数相等。
- 输入x和y的维度大小,除了cancatDim维之外,其他维度要求相同。
接口调用示例
- 输入
x shape = [3, 2, 3]: [[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]] y shape = [3, 1, 3]: [[[2.0, 2.0, 2.0]], [[2.0, 2.0, 2.0]], [[2.0, 2.0, 2.0]]] param.concatDim = 1(或-2)
- 输出
output shape = [3, 3, 3] [[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [2.0, 2.0, 2.0]], [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [2.0, 2.0, 2.0]], [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]]