从输入张量中根据索引收集切片,并将这些切片组合成一个新的张量。

1 2 3 4 5 | struct GatherParam { int64_t axis = 0; int64_t batchDims = 0; uint8_t rsv[16] = {0}; }; |
成员名称 |
类型 |
默认值 |
描述 |
|---|---|---|---|
axis |
int64_t |
0 |
指定要收集切片的轴。默认值为0。 “axis”必须大于或等于0。 |
batchDims |
int64_t |
0 |
代表批处理的维度数。表示可以从每轮批处理的元素中分别取出满足要求的切片数据。例如,如果batchDims=1,则代表在params的第一个轴上有一个外循环indices,见示例2。 "batchDims"必须大于或等于0,且小于或等于axis。 |
rsv[16] |
uint8_t |
{0} |
预留参数。 |
参数 |
维度 |
数据类型 |
格式 |
描述 |
|---|---|---|---|---|
x |
[dim_0,dim_1,...,dim_n] |
float16/float/bf16/int32/uint32 |
ND |
输入tensor。 |
indices |
[dim_0,dim_1,...,dim_n] |
int64/int32/uint32 |
ND |
索引表,值必须在[0, x.shape[axis]]范围内,x与indices的维数之和小于等于9。 indices的维数必须大于等于“batchdims”。 |
参数 |
维度 |
数据类型 |
格式 |
描述 |
|---|---|---|---|---|
output |
[dim_0,dim_1,...,dim_n] |
float16/float/bf16/int32/uint32 |
ND |
输出tensor。数据类型与x保持一致。 |
axis = 1;
batchDims = 0;
输入tensor为:
x= [[1,2,3],
[4,5,6],
[7,8,9]]
indices tensor为:
indices=[2,0]
根据indices tensor的值,在axis轴获取params数据切片,output tensor为:
output=[[3, 1],
[6, 4],
[9, 7]]
axis= 1;
batchDims = 1;
输入tensor为:
x= [[1,2,3],
[4,5,6],
[7,8,9]]
indices tensor为:
indices=[[1],
[2],
[0]]
因为batch_dims=1,则代表在第一个轴上(即轴0)进行批处理。在轴0上,将x[i]和indices[i]进行一一对应的gather处理,根据indices的值,在axis轴获取x的数据切片, 其中i为batch轴的坐标,output tensor为:
output= [[2,],
[6,],
[7,]]