昇腾社区首页
中文
注册

GatherOperation

功能

从输入张量中根据索引收集切片,并将这些切片组合成一个新的张量。

图1 GatherOperation算子上下文

定义

struct GatherParam {
    int64_t axis = 0;
    int64_t batchDims = 0;
};

参数列表

成员名称

类型

默认值

描述

axis

int64_t

0

指定要收集切片的轴。默认值为0。

“axis”必须大于或等于0。

batchDims

int64_t

0

代表批处理的维度数。表示可以从每轮批处理的元素中分别取出满足要求的切片数据。例如,如果batchDims=1,则代表在params的第一个轴上有一个外循环indices,见示例2。

"batchDims"必须大于或等于0,且小于或等于axis。

输入

参数

维度

数据类型

格式

描述

x

[dim_0,dim_1,...,dim_n]

float16/float/bf16/int32/uint32

ND

输入tensor。

indexs

[dim_0,dim_1,...,dim_n]

int64/int32/uint32

ND

索引表,值必须在[0, x.shape[axis]]范围内,x与indexs的维数之和小于等于9。

indexs的维数必须大于等于“batchdims”

输出

参数

维度

数据类型

格式

描述

output

[dim_0,dim_1,...,dim_n]

float16/float/bf16/int32/uint32

ND

输出tensor。

规格约束

  • 索引表indices值必须在[0, x.shape[axis]]范围内,x与indexs的维数之和小于等于9。indexs的维数必须大于等于"batchdims"。
  • 成员"axis" 必须大于等于0。
  • "batchDims"必须大于或等于0,且小于或等于axis。

接口调用示例

  • 示例1:
    axis = 1;
    batchDims = 0;
    输入tensor为:
     x=  [[1,2,3],
          [4,5,6],
          [7,8,9]]
    indices tensor为:
    indices=[2,0]
    根据indices tensor的值,在axis轴获取params数据切片,output tensor为:
    output=[[3, 1],
            [6, 4],
            [9, 7]]
  • 示例2:
    axis= 1;
    batchDims = 1;
    输入tensor为:
     x=  [[1,2,3],
          [4,5,6],
          [7,8,9]]
    indices tensor为:
    indices=[[1],
             [2],
             [0]]
    因为batch_dims=1,则代表在第一个轴上(即轴0)进行批处理。在轴0上,将x[i]和indices[i]进行一一对应的gather处理,根据indices的值,在axis轴获取x的数据切片, 其中i为batch轴的坐标,output tensor为:
    output= [[2,],
            [6,],
            [7,]]