将输入张量沿指定维度切分成多个张量。
参数 |
描述 |
---|---|
x |
输入的张量,即希望分割的张量。 |
splitNum |
是一个整数,表示等分次数,当前支持2或3。 |
splitDim |
可选,默认0,沿着这个维度进行拆分。默认情况下,沿着第一个维度(即批次维度)进行拆分。 |
返回若干张量。这些张量是从原张量中根据指定的方式切分出来的。
1 2 3 4 5 6 | struct SplitParam { int32_t splitDim = 0; int32_t splitNum = 2; SVector<int32_t> splitSizes = {}; uint8_t rsv[8] = {0}; }; |
成员名称 |
类型 |
默认值 |
描述 |
---|---|---|---|
splitDim |
int32_t |
0 |
指定切分的维度索引。 splitDim须位于输入张量x的维度范围内,即如果x的维度为xDim,则splitDim的取值范围为[-xDim, xDim - 1]。 当splitDim为负数时,其含义是从最高维度开始访问,如splitDim = -1,x维度数为dimNum,则拆分维度为dimNum - 1。 |
splitNum |
int32_t |
2 |
等分次数,当前支持2或3。 输入张量x的维度须能够被splitNum整除,且当splitNum = 3时输入x要求是float16或者bf16数据类型。 |
splitSizes |
SVector<int32_t> |
- |
指定每个输出tensor在切分维度上的大小,不传入此参数时使用等长切分,传入此参数时使用splitV不等长切分。 |
rsv[8] |
uint8_t |
{0} |
预留参数。 |
参数 |
维度 |
数据类型 |
格式 |
描述 |
---|---|---|---|---|
x |
[dim_0, ..., dim_splitDim, ..., dim_n] |
float16/int64/bf16 |
ND |
输入,最高支持8维。 |
output1 |
|
float16/int64/bf16 |
ND |
输出,切分后的tensor。数据类型与x一致。 |
output2 |
|
float16/int64/bf16 |
ND |
输出,切分后的tensor。数据类型与x一致。 |
参数 |
维度 |
数据类型 |
格式 |
描述 |
---|---|---|---|---|
x |
[dim_0, ..., dim_splitDim,...,dim_n] |
float16/bf16 |
ND |
输入,最高支持8维。 |
output1 |
|
float16/bf16 |
ND |
输出,切分后的tensor。数据类型与x一致。 |
output2 |
|
float16/bf16 |
ND |
输出,切分后的tensor。数据类型与x一致。 |
output3 |
|
float16/bf16 |
ND |
输出,切分后的tensor。数据类型与x一致。 |
输入:
splitDim = 0 splitNum = 3 x = [3, 3, 3, 3, 3, 3, 3, 3, 3]
输出:
z = [3, 3, 3] z1 = [3, 3, 3] z2 = [3, 3, 3]
前置条件和编译命令请参见算子调用示例。
场景:基础场景。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 | #include <iostream> #include <vector> #include <numeric> #include <random> #include "acl/acl.h" #include "atb/operation.h" #include "atb/types.h" #include "atb/atb_infer.h" #include "demo_util.h" int main(int argc, char **argv) { // 设置卡号、创建context、设置stream CHECK_STATUS(aclInit(nullptr)); int32_t deviceId = 0; CHECK_STATUS(aclrtSetDevice(deviceId)); atb::Context *context = nullptr; CHECK_STATUS(atb::CreateContext(&context)); void *stream = nullptr; CHECK_STATUS(aclrtCreateStream(&stream)); context->SetExecuteStream(stream); // 配置Op参数 atb::infer::SplitParam opParam; opParam.splitDim = 1; // 设定切分轴为1 opParam.splitNum = 2; // 设置切分后得到的块数 opParam.splitSizes = {2, 3}; // 设置不均匀切分时每块大小 // 准备VariantPack atb::VariantPack variantPack; std::vector<int64_t> inputXShape = {1, 5, 2}; std::vector<int64_t> output1Shape = {1, 2, 2}; std::vector<int64_t> output2Shape = {1, 3, 2}; std::vector<float> inTensorXData = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; std::vector<float> output1RefData = {0, 1, 2, 3}; std::vector<float> output2RefData = {4, 5, 6, 7, 8, 9}; atb::Tensor inTensorX = CreateTensorFromVector(context, stream, inTensorXData, ACL_FLOAT16, aclFormat::ACL_FORMAT_ND, inputXShape); atb::Tensor outTensor1 = CreateTensor(ACL_FLOAT16, aclFormat::ACL_FORMAT_ND, output1Shape); atb::Tensor outTensor2 = CreateTensor(ACL_FLOAT16, aclFormat::ACL_FORMAT_ND, output2Shape); variantPack.inTensors = {inTensorX}; variantPack.outTensors = {outTensor1, outTensor2}; // 申请SplitOp atb::Operation *splitOp = {nullptr}; CHECK_STATUS(atb::CreateOperation(opParam, &splitOp)); uint64_t workspaceSize = 0; // ATB Operation 第一阶段接口调用:对输入输出进行检查,并根据需要计算workspace大小 CHECK_STATUS(splitOp->Setup(variantPack, workspaceSize, context)); uint8_t *workspacePtr = nullptr; if (workspaceSize > 0) { CHECK_STATUS(aclrtMalloc((void **)(&workspacePtr), workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST)); } // ATB Operation 第二阶段接口调用:执行算子 CHECK_STATUS(splitOp->Execute(variantPack, workspacePtr, workspaceSize, context)); CHECK_STATUS(aclrtSynchronizeStream(stream)); // 流同步,等待device侧任务计算完成 // 资源释放 for (atb::Tensor &inTensor : variantPack.inTensors) { CHECK_STATUS(aclrtFree(inTensor.deviceData)); } for (atb::Tensor &outTensor : variantPack.outTensors) { CHECK_STATUS(aclrtFree(outTensor.deviceData)); } if (workspaceSize > 0) { CHECK_STATUS(aclrtFree(workspacePtr)); } CHECK_STATUS(atb::DestroyOperation(splitOp)); CHECK_STATUS(aclrtDestroyStream(stream)); CHECK_STATUS(atb::DestroyContext(context)); CHECK_STATUS(aclFinalize()); std::cout << "Split demo success!" << std::endl; return 0; } |