昇腾社区首页
中文
注册
开发者
下载

SplitOperation

产品支持情况

硬件型号

是否支持

Atlas A3 推理系列产品/Atlas A3 训练系列产品

Atlas A2 训练系列产品/Atlas 800I A2 推理产品

Atlas 训练系列产品

Atlas 推理系列产品

Atlas 200I/500 A2 推理产品

功能说明

将输入张量沿指定维度切分成多个张量。

算子上下文

图1 SplitOperation

算子功能实现

  • 参数:

    参数

    描述

    x

    输入的张量,即希望分割的张量。

    splitNum

    是一个整数,表示等分次数,当前支持2或3。

    splitDim

    可选,默认0,沿着这个维度进行拆分。默认情况下,沿着第一个维度(即批次维度)进行拆分。

  • 返回值:

    返回若干张量。这些张量是从原张量中根据指定的方式切分出来的。

定义

1
2
3
4
5
6
struct SplitParam {
    int32_t splitDim = 0;
    int32_t splitNum = 2;
    SVector<int32_t> splitSizes = {};
    uint8_t rsv[8] = {0};
};

参数列表

成员名称

类型

默认值

描述

splitDim

int32_t

0

指定切分的维度索引。

splitDim须位于输入张量x的维度范围内,即如果x的维度为xDim,则splitDim的取值范围为[-xDim, xDim - 1]。 当splitDim为负数时,其含义是从最高维度开始访问,如splitDim = -1,x维度数为dimNum,则拆分维度为dimNum - 1。

splitNum

int32_t

2

等分次数,当前支持2或3。

输入张量x的维度须能够被splitNum整除,且当splitNum = 3时输入x要求是float16或者bf16数据类型。

splitSizes

SVector<int32_t>

-

指定每个输出tensor在切分维度上的大小,不传入此参数时使用等长切分,传入此参数时使用不等长切分。

rsv[8]

uint8_t

{0}

预留参数。

splitNum=2时输入输出

参数

维度

数据类型

格式

描述

x

[dim_0, ..., dim_splitDim, ..., dim_n]

float16/int64/bf16

ND

输入,最高支持8维。

output1

  • [dim_0, …, dim_splitDim_1, ..., dim_n]
  • 当splitSizes为空时,dim_splitDim_1=dim_splitDim/2;
  • 当splitSizes不为空时,dim_splitDim_2=splitSizes[0]。

float16/int64/bf16

ND

输出,切分后的tensor。数据类型与x一致。

output2

  • [dim_0, …, dim_splitDim_2, ..., dim_n]
  • 当splitSizes为空时,dim_splitDim_2=dim_splitDim/2;
  • 当splitSizes不为空时,dim_splitDim_2=splitSizes[1]。

float16/int64/bf16

ND

输出,切分后的tensor。数据类型与x一致。

splitNum=3时输入输出

参数

维度

数据类型

格式

描述

x

[dim_0, ..., dim_splitDim,...,dim_n]

float16/bf16

ND

输入,最高支持8维。

output1

  • [dim_0, …, dim_splitDim_1, ..., dim_n]
  • 当splitSizes为空时,dim_splitDim_1=dim_splitDim/2;
  • 当splitSizes不为空时,dim_splitDim_1=splitSizes[0]。

float16/bf16

ND

输出,切分后的tensor。数据类型与x一致。

output2

  • [dim_0, …, dim_splitDim_2, ..., dim_n]
  • 当splitSizes为空时,dim_splitDim_2=dim_splitDim/3;
  • 当splitSizes不为空时,dim_splitDim_2=splitSizes[1]。

float16/bf16

ND

输出,切分后的tensor。数据类型与x一致。

output3

  • [dim_0, …, dim_splitDim_3, ..., dim_n]
  • 当splitSizes为空时,dim_splitDim_3=dim_splitDim/3;
  • 当splitSizes不为空时,dim_splitDim_3=splitSizes[2]。

float16/bf16

ND

输出,切分后的tensor。数据类型与x一致。

约束说明

  • splitDim必须位于输入张量x的维度范围内,即如果x的维度为xDim,则splitDim的取值范围为[-xDim, xDim - 1]。 当splitDim为负数时,其含义是从最高维度开始访问,如splitDim = -1,x维度数为dimNum,则拆分维度为dimNum - 1。使用不等长切分时不支持负数索引。
  • splitSizes非空时,维度为splitNum,其每一个元素要求大于等于1。元素之和等于dim_splitDim的大小。
  • splitSizes为空时,输入张量x的维度必须能够被splitNum整除。
  • 当splitNum = 3时,要求输入x的数据类型是float16或者bf16。
  • bf16数据类型仅支持Atlas A2 训练系列产品/Atlas 800I A2 推理产品Atlas A3 推理系列产品/Atlas A3 训练系列产品

接口调用示例

输入:

splitDim = 0
splitNum = 3
x = [3, 3, 3, 3, 3, 3, 3, 3, 3]

输出:

z = [3, 3, 3]
z1 = [3, 3, 3]
z2 = [3, 3, 3]