昇腾社区首页
中文
注册

SortOperation

产品支持情况

硬件型号

是否支持

特殊说明

Atlas A3 推理系列产品/Atlas A3 训练系列产品

-

Atlas A2 训练系列产品/Atlas 800I A2 推理产品

-

Atlas 训练系列产品

不支持float数据类型。

Atlas 推理系列产品

不支持float、bf16数据类型。

Atlas 200I/500 A2 推理产品

不支持float数据类型。

功能说明

后处理计算功能。实现输入tensor在最后一维上降序排列,并保留最大的num个元素,输出排序后的tensor及各元素对应的索引。

算子上下文

图1 SortOperation算子上下文

定义

1
2
3
4
struct SortParam {
    SVector<int32_t> num;
    uint8_t rsv[8] = {0};
};

参数列表

成员名称

类型

默认值

描述

num

SVector<int32_t>

-

排序后保留的最大的元素的数量。

num是一个仅含有一个值的SVector,该值需大于0且小于等于输入x最后一维的大小。

rsv[8]

uint8_t

{0}

预留参数。

输入

参数

维度

数据类型

格式

描述

x

[dim_0,dim_1,...,dim_n]

float16/bf16/float

ND

最后一维应至少有num个元素。

输出

参数

维度

数据类型

格式

描述

output

[dim_0,dim_1,...,num]

float16/bf16/float

ND

最后一维排序后,最大的num个元素。

indices

[dim_0,dim_1,...,num]

int32

ND

最大的num个元素对应的原索引。

约束说明

  • num是一个仅含有一个值的SVector,该值需大于0且小于等于输入x最后一维的大小。
  • Atlas 推理系列产品不支持num小于16、非尾轴的维度大于580000000的场景。

接口调用示例

输入:

num = [1]
x = [[3.0, 4.0],
       [5.0, 6.0]]

输出:

output = [[4.0],
                 [6.0]]
indices = [[1.0],
                 [1.0]]