从输入张量中根据索引收集切片,并将这些切片组合成一个新的张量。
struct GatherParam { int64_t axis = 0; int64_t batchDims = 0; };
成员名称 |
类型 |
默认值 |
描述 |
---|---|---|---|
axis |
int64_t |
0 |
指定要收集切片的轴。默认值为0。 “axis”必须大于或等于0。 |
batchDims |
int64_t |
0 |
代表批处理的维度数。表示可以从每轮批处理的元素中分别取出满足要求的切片数据。例如,如果batchDims=1,则代表在params的第一个轴上有一个外循环indices,见示例2。 "batchDims"必须大于或等于0,且小于或等于axis。 |
参数 |
维度 |
数据类型 |
格式 |
描述 |
---|---|---|---|---|
x |
[dim_0,dim_1,...,dim_n] |
float16/float/bf16/int32/uint32 |
ND |
输入tensor。 |
indexs |
[dim_0,dim_1,...,dim_n] |
int64/int32/uint32 |
ND |
索引表,值必须在[0, x.shape[axis]]范围内,x与indexs的维数之和小于等于9。 indexs的维数必须大于等于“batchdims”。 |
参数 |
维度 |
数据类型 |
格式 |
描述 |
---|---|---|---|---|
output |
[dim_0,dim_1,...,dim_n] |
float16/float/bf16/int32/uint32 |
ND |
输出tensor。 |
axis = 1; batchDims = 0; 输入tensor为: x= [[1,2,3], [4,5,6], [7,8,9]] indices tensor为: indices=[2,0] 根据indices tensor的值,在axis轴获取params数据切片,output tensor为: output=[[3, 1], [6, 4], [9, 7]]
axis= 1; batchDims = 1; 输入tensor为: x= [[1,2,3], [4,5,6], [7,8,9]] indices tensor为: indices=[[1], [2], [0]] 因为batch_dims=1,则代表在第一个轴上(即轴0)进行批处理。在轴0上,将x[i]和indices[i]进行一一对应的gather处理,根据indices的值,在axis轴获取x的数据切片, 其中i为batch轴的坐标,output tensor为: output= [[2,], [6,], [7,]]