SplitOperation
产品支持情况
硬件型号 |
是否支持 |
---|---|
√ |
|
√ |
|
√ |
|
√ |
|
√ |
功能说明
将输入张量沿指定维度切分成多个张量。
算子上下文

算子功能实现
- 参数:
参数
描述
x
输入的张量,即希望分割的张量。
splitNum
是一个整数,表示等分次数,当前支持2或3。
splitDim
可选,默认0,沿着这个维度进行拆分。默认情况下,沿着第一个维度(即批次维度)进行拆分。
- 返回值:
返回若干张量。这些张量是从原张量中根据指定的方式切分出来的。
定义
1 2 3 4 5 6 | struct SplitParam { int32_t splitDim = 0; int32_t splitNum = 2; SVector<int32_t> splitSizes = {}; uint8_t rsv[8] = {0}; }; |
参数列表
成员名称 |
类型 |
默认值 |
描述 |
---|---|---|---|
splitDim |
int32_t |
0 |
指定切分的维度索引。 splitDim须位于输入张量x的维度范围内,即如果x的维度为xDim,则splitDim的取值范围为[-xDim, xDim - 1]。 当splitDim为负数时,其含义是从最高维度开始访问,如splitDim = -1,x维度数为dimNum,则拆分维度为dimNum - 1。 |
splitNum |
int32_t |
2 |
等分次数,当前支持2或3。 输入张量x的维度须能够被splitNum整除,且当splitNum = 3时输入x要求是float16或者bf16数据类型。 |
splitSizes |
SVector<int32_t> |
- |
指定每个输出tensor在切分维度上的大小,不传入此参数时使用等长切分,传入此参数时使用不等长切分。 |
rsv[8] |
uint8_t |
{0} |
预留参数。 |
splitNum=2时输入输出
参数 |
维度 |
数据类型 |
格式 |
描述 |
---|---|---|---|---|
x |
[dim_0, ..., dim_splitDim, ..., dim_n] |
float16/int64/bf16 |
ND |
输入,最高支持8维。 |
output1 |
|
float16/int64/bf16 |
ND |
输出,切分后的tensor。数据类型与x一致。 |
output2 |
|
float16/int64/bf16 |
ND |
输出,切分后的tensor。数据类型与x一致。 |
splitNum=3时输入输出
参数 |
维度 |
数据类型 |
格式 |
描述 |
---|---|---|---|---|
x |
[dim_0, ..., dim_splitDim,...,dim_n] |
float16/bf16 |
ND |
输入,最高支持8维。 |
output1 |
|
float16/bf16 |
ND |
输出,切分后的tensor。数据类型与x一致。 |
output2 |
|
float16/bf16 |
ND |
输出,切分后的tensor。数据类型与x一致。 |
output3 |
|
float16/bf16 |
ND |
输出,切分后的tensor。数据类型与x一致。 |
约束说明
- splitDim必须位于输入张量x的维度范围内,即如果x的维度为xDim,则splitDim的取值范围为[-xDim, xDim - 1]。 当splitDim为负数时,其含义是从最高维度开始访问,如splitDim = -1,x维度数为dimNum,则拆分维度为dimNum - 1。使用不等长切分时不支持负数索引。
- splitSizes非空时,维度为splitNum,其每一个元素要求大于等于1。元素之和等于dim_splitDim的大小。
- splitSizes为空时,输入张量x的维度必须能够被splitNum整除。
- 当splitNum = 3时,要求输入x的数据类型是float16或者bf16。
- bf16数据类型仅支持
Atlas A2 训练系列产品 /Atlas 800I A2 推理产品 和Atlas A3 推理系列产品 /Atlas A3 训练系列产品 。
接口调用示例
输入:
splitDim = 0 splitNum = 3 x = [3, 3, 3, 3, 3, 3, 3, 3, 3]
输出:
z = [3, 3, 3] z1 = [3, 3, 3] z2 = [3, 3, 3]