gather
功能说明
获取输入Tensor的特定维度中指定索引的切片数据。
例如:
- 输入Tensor params为:params = [[1,2,3],[4,5,6],[7,8,9]]
- 输入Tensor indices为:indices = [2,0]
- 输入axis = 0
调用此接口,则代表需要在params的第0维中,分别取索引为2和索引为0的数,结果如下所示:
gather_tensor = [[7,8,9], [1,2,3]]
函数原型
gather(params, indices, axis=None, batch_dims=0, impl_mode="support_out_of_bound_index")
参数说明
- params:待切片的数据。输入Tensor,tvm.tensor类型。
支持的数据类型:float16, float32, int8, uint8, int16, uint16, int32, uint32, int64, uint64, bool。
- indices:待取出数据的位置索引。输入Tensor,tvm.tensor类型,支持的数据类型:int32, int64。
- axis:指定切片数据所在的维度。取值范围:[-p, p -1],且要大于等于batch_dims,其中p是params的维数,int类型。
- batch_dims:可选参数,代表批处理的维度数。表示可以从每轮批处理的元素中分别取出满足要求的切片数据。例如,如果batch_dims=1,则代表在params的第一个轴上有一个外循环indices,详细功能可参见•示例:batch_dims为1。
batch_dims的取值范围为[-i, i],其中i是indices的维数。
数据类型为int,且需要满足:batch_dims <=axis。
- impl_mode:可选参数,算子处理时选择的模式,默认值为"support_out_of_bound_index",亦可设置为None。"support_out_of_bound_index"模式下,当索引值超范围(正常索引的范围为[-params.shape[axis], params.shape[axis]))时,对应的输出置零。其他模式下,当索引值超范围时,会报AI Core Error。
返回值
gather_tensor:执行gather之后的Tensor,tvm.tensor类型
约束说明
- 若配置了batch_dims,即batch_dims≠0,则params与indices的batch轴的维度大小要保持一致。例如,若batch_dims=1,表示要在第一个轴上做批处理,则params与indices的轴0的维度大小要相同。
- 此接口暂不支持与其他TBE DSL计算接口混合使用。
支持的型号
Atlas 训练系列产品
Atlas 推理系列产品(Ascend 310P处理器)
Atlas 200/500 A2推理产品
Atlas A2训练系列产品/Atlas 800I A2推理产品
调用示例
- 示例:batch_dims为0
params = [[1,2,3], [4,5,6], [7,8,9]]
输入Tensor indices为:
indices = [2,0]
输入axis = 1, batch_dims=0
调用gather接口,根据indices tensor的值,在axis轴获取params数据切片,代码示例如下所示:from tbe import tvm from tbe import dsl params = tvm.placeholder((3,3), dtype=dtype, name="params") indices = tvm.placeholder((2,), dtype=dtype, name="indices") set_valued_tensor = dsl.gather(params,indices,0, 0)
以上代码的功能为获取params 0轴上index为2和0的数据切片,输出结果如下所示:
gather_tensor = [[3, 1], [6, 4], [9, 7]]
- 示例:batch_dims为1
params = [[1,2,3], [4,5,6], [7,8,9]]
输入Tensor indices为:
indices = [[1], [2], [0]]
输入axis = 1, batch_dims=1
调用gather接口,代码示例如下所示:from tbe import tvm from tbe import dsl params = tvm.placeholder((3, 3), dtype=dtype, name="params") indices = tvm.placeholder((3, 1), dtype=dtype, name="indices") set_valued_tensor = dsl.gather(params, indices, 1, 1)
因为batch_dims=1,则代表在第一个轴上(即轴0)进行批处理。在轴0上,将params[i]和indices[i]进行一一对应的gather处理,根据indices的值,在axis轴获取params数据切片, 其中i为batch轴的坐标。
输出结果如下所示:
gather_tensor = [[2,], [6,], [7,]]