GatherOperation
产品支持情况
硬件型号  | 
是否支持  | 
特殊说明  | 
|---|---|---|
√  | 
-  | 
|
√  | 
-  | 
|
√  | 
-  | 
|
√  | 
-  | 
|
√  | 
输入x数据类型只支持float16。  | 
功能说明
从输入张量中根据索引收集切片,并将这些切片组合成一个新的张量。
图1 GatherOperation算子上下文


定义
1 2 3 4 5  | struct GatherParam { int64_t axis = 0; int64_t batchDims = 0; uint8_t rsv[16] = {0}; };  | 
参数列表
成员名称  | 
类型  | 
默认值  | 
取值范围  | 
是否必选  | 
描述  | 
|---|---|---|---|---|---|
axis  | 
int64_t  | 
0  | 
>=0且<=n  | 
是  | 
指定要收集切片的哪个轴。默认值为0。n为输入x的维度最大索引。  | 
batchDims  | 
int64_t  | 
0  | 
>=0且<=min(axis,m)  | 
是  | 
代表批处理的维度数。表示可以从每轮批处理的元素中分别取出满足要求的切片数据。 例如,如果batchDims=1,则代表在x的第(axis - batchDims)轴上有一个外循环,见示例2。 m为输入indices的维度最大索引。  | 
rsv[16]  | 
uint8_t  | 
{0}  | 
[0]  | 
否  | 
预留参数。  | 
输入
输出
参数  | 
维度  | 
数据类型  | 
格式  | 
描述  | 
|---|---|---|---|---|
output  | 
[dim_x_0,dim_x_1,...,dim_x_n]  | 
float16/float/bfloat16/int32/uint32/int64  | 
ND  | 
输出tensor。数据类型与x保持一致。 
  | 
约束说明
因output的维度数小于等于8,故(n+1)+(m+1)-1-batchdims <= 8。
接口调用示例
- 示例1:
axis = 1; batchDims = 0; 输入tensor为: x= [[1,2,3], [4,5,6], [7,8,9]] indices tensor为: indices=[2,0] 根据indices tensor的值,在axis轴获取x数据切片,output tensor为: output=[[3, 1], [6, 4], [9, 7]] - 示例2:
axis= 1; batchDims = 1; 输入tensor为: x= [[1,2,3], [4,5,6], [7,8,9]] indices tensor为: indices=[[1], [2], [0]] 因为batch_dims=1,则代表在第一个轴上(即轴0)进行批处理。在轴0上,将x[i]和indices[i]进行一一对应的gather处理,根据indices的值,在axis轴获取x的数据切片, 其中i为batch轴的坐标,output tensor为: output= [[2,], [6,], [7,]] 上面等价于: def gather_fun(x, indices, axis): batch_dims=1 res= [] # 进行外循环 for p,i in zip(x, indices): r = tf.gather(p, i, axis=axis-batch_dims) res.append(r) return tf.stack(res)