从输入张量中根据索引收集切片,并将这些切片组合成一个新的张量。
硬件型号 |
支持情况 |
特殊说明 |
---|---|---|
支持 |
输入x数据类型只支持float16。 |
|
支持 |
- |
|
支持 |
- |
|
支持 |
- |
1 2 3 4 5 | struct GatherParam { int64_t axis = 0; int64_t batchDims = 0; uint8_t rsv[16] = {0}; }; |
成员名称 |
类型 |
默认值 |
取值范围 |
是否必选 |
描述 |
---|---|---|---|---|---|
axis |
int64_t |
0 |
≥0 |
是 |
指定要收集切片的轴。默认值为0。 |
batchDims |
int64_t |
0 |
≥0且≤axis |
是 |
代表批处理的维度数。表示可以从每轮批处理的元素中分别取出满足要求的切片数据。 例如,如果batchDims=1,则代表在x的第(axis - batchDims)轴上有一个外循环,见示例2。 |
rsv[16] |
uint8_t |
{0} |
[0] |
否 |
预留参数。 |
参数 |
维度 |
数据类型 |
格式 |
描述 |
---|---|---|---|---|
x |
[dim_0,dim_1,...,dim_n] |
float16/float/bf16/int32/uint32/int64 |
ND |
输入tensor。 |
indexs |
[dim_0,dim_1,...,dim_n] |
int64/int32/uint32 |
ND |
索引表,值必须在[0, x.shape[axis]]范围内,x与indexs的维数之和小于等于9。 indexs的维数必须大于等于“batchdims”。 |
参数 |
维度 |
数据类型 |
格式 |
描述 |
---|---|---|---|---|
output |
[dim_0,dim_1,...,dim_n] |
float16/float/bf16/int32/uint32/int64 |
ND |
输出tensor。数据类型与x保持一致。 |
axis = 1; batchDims = 0; 输入tensor为: x= [[1,2,3], [4,5,6], [7,8,9]] indices tensor为: indices=[2,0] 根据indices tensor的值,在axis轴获取x数据切片,output tensor为: output=[[3, 1], [6, 4], [9, 7]]
axis= 1; batchDims = 1; 输入tensor为: x= [[1,2,3], [4,5,6], [7,8,9]] indices tensor为: indices=[[1], [2], [0]] 因为batch_dims=1,则代表在第一个轴上(即轴0)进行批处理。在轴0上,将x[i]和indices[i]进行一一对应的gather处理,根据indices的值,在axis轴获取x的数据切片, 其中i为batch轴的坐标,output tensor为: output= [[2,], [6,], [7,]] 上面等价于: def gather_fun(x, indices, axis): batch_dims=1 res= [] # 进行外循环 for p,i in zip(x, indices): r = tf.gather(p, i, axis=axis-batch_dims) res.append(r) return tf.stack(res)
前置条件和编译命令请参见算子调用示例。
场景:基础场景。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 | #include <iostream> #include <vector> #include <numeric> #include "acl/acl.h" #include "atb/operation.h" #include "atb/types.h" #include "atb/atb_infer.h" #include "demo_util.h" /** * @brief 准备atb::VariantPack中的所有输入tensor * @param contextPtr context指针 * @param stream stream * @param seqLenHost host侧tensor。序列长度向量,等于1时,为增量或全量;大于1时,为全量 * @param tokenOffsetHost host侧tensor。计算完成后的token偏移 * @param layerId layerId,取cache的kv中哪一个kv进行计算 * @return atb::SVector<atb::Tensor> atb::VariantPack中的输入tensor * @note 需要传入所有host侧tensor */ atb::SVector<atb::Tensor> PrepareInTensor(atb::Context *contextPtr, aclrtStream stream) { uint32_t dim0 = 3; uint32_t dim1 = 3; // 创建tensor0 std::vector<float> tensorzero{1, 2, 3, 4, 5, 6, 7, 8, 9}; atb::Tensor tensorZero = CreateTensorFromVector(contextPtr, stream, tensorzero, ACL_FLOAT16, aclFormat::ACL_FORMAT_ND, {dim0, dim1}); // 创建tensor1 std::vector<int64_t> tensorone{2, 1}; atb::Tensor tensorOne = CreateTensorFromVector(contextPtr, stream, tensorone, ACL_INT64, aclFormat::ACL_FORMAT_ND, {2}); // 根据顺序将所有输入tensor放入SVector atb::SVector<atb::Tensor> inTensors = {tensorZero, tensorOne}; return inTensors; } /** * @brief 创建一个Gather的Operation,并设置参数 * @return atb::Operation * 返回一个Operation指针 */ atb::Operation *PrepareOperation() { atb::infer::GatherParam gatherParam; gatherParam.axis = 0; gatherParam.batchDims = 0; atb::Operation *op = nullptr; CHECK_STATUS(atb::CreateOperation(gatherParam, &op)); return op; } int main(int argc, char **argv) { // 1.设置卡号、创建context、设置stream CHECK_STATUS(aclInit(nullptr)); int32_t deviceId = 0; CHECK_STATUS(aclrtSetDevice(deviceId)); atb::Context *context = nullptr; CHECK_STATUS(atb::CreateContext(&context)); void *stream = nullptr; CHECK_STATUS(aclrtCreateStream(&stream)); context->SetExecuteStream(stream); // Gather示例 atb::Operation *op = PrepareOperation(); // 准备输入张量 atb::VariantPack variantPack; variantPack.inTensors = PrepareInTensor(context, stream); // 放入输入tensor atb::Tensor tensorOut = CreateTensor(ACL_FLOAT16, aclFormat::ACL_FORMAT_ND, {2, 3}); // 创建输出tensor variantPack.outTensors.push_back(tensorOut); // 放入输出tensor // setup阶段,计算workspace大小 uint64_t workspaceSize = 0; CHECK_STATUS(op->Setup(variantPack, workspaceSize, context)); uint8_t *workspacePtr = nullptr; if (workspaceSize > 0) { CHECK_STATUS(aclrtMalloc((void **)(&workspacePtr), workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST)); } // execute阶段 op->Execute(variantPack, workspacePtr, workspaceSize, context); CHECK_STATUS(aclrtSynchronizeStream(stream)); // 流同步,等待device侧任务计算完成 // 释放内存 for (atb::Tensor &inTensor : variantPack.inTensors) { CHECK_STATUS(aclrtFree(inTensor.deviceData)); } if (workspaceSize > 0) { CHECK_STATUS(aclrtFree(workspacePtr)); } // 资源释放 CHECK_STATUS(atb::DestroyOperation(op)); // operation,对象概念,先释放 CHECK_STATUS(aclrtDestroyStream(stream)); CHECK_STATUS(atb::DestroyContext(context)); // context,全局资源,后释放 std::cout << "Gather demo success!" << std::endl; return 0; } |