GatherOperation
产品支持情况
硬件型号 |
是否支持 |
特殊说明 |
---|---|---|
√ |
- |
|
√ |
- |
|
√ |
- |
|
√ |
- |
|
√ |
输入x数据类型只支持float16。 |
功能说明
从输入张量中根据索引收集切片,并将这些切片组合成一个新的张量。
图1 GatherOperation算子上下文


定义
1 2 3 4 5 | struct GatherParam { int64_t axis = 0; int64_t batchDims = 0; uint8_t rsv[16] = {0}; }; |
参数列表
成员名称 |
类型 |
默认值 |
取值范围 |
是否必选 |
描述 |
---|---|---|---|---|---|
axis |
int64_t |
0 |
>=0且<=n |
是 |
指定要收集切片的哪个轴。默认值为0。n为输入x的维度最大索引。 |
batchDims |
int64_t |
0 |
>=0且<=min(axis,m) |
是 |
代表批处理的维度数。表示可以从每轮批处理的元素中分别取出满足要求的切片数据。 例如,如果batchDims=1,则代表在x的第(axis - batchDims)轴上有一个外循环,见示例2。 m为输入indices的维度最大索引。 |
rsv[16] |
uint8_t |
{0} |
[0] |
否 |
预留参数。 |
输入
输出
参数 |
维度 |
数据类型 |
格式 |
描述 |
---|---|---|---|---|
output |
[dim_x_0,dim_x_1,...,dim_x_n] |
float16/float/bfloat16/int32/uint32/int64 |
ND |
输出tensor。数据类型与x保持一致。
|
约束说明
因output的维度数小于等于8,故(n+1)+(m+1)-1-batchdims <= 8。
接口调用示例
- 示例1:
axis = 1; batchDims = 0; 输入tensor为: x= [[1,2,3], [4,5,6], [7,8,9]] indices tensor为: indices=[2,0] 根据indices tensor的值,在axis轴获取x数据切片,output tensor为: output=[[3, 1], [6, 4], [9, 7]]
- 示例2:
axis= 1; batchDims = 1; 输入tensor为: x= [[1,2,3], [4,5,6], [7,8,9]] indices tensor为: indices=[[1], [2], [0]] 因为batch_dims=1,则代表在第一个轴上(即轴0)进行批处理。在轴0上,将x[i]和indices[i]进行一一对应的gather处理,根据indices的值,在axis轴获取x的数据切片, 其中i为batch轴的坐标,output tensor为: output= [[2,], [6,], [7,]] 上面等价于: def gather_fun(x, indices, axis): batch_dims=1 res= [] # 进行外循环 for p,i in zip(x, indices): r = tf.gather(p, i, axis=axis-batch_dims) res.append(r) return tf.stack(res)