后处理计算功能。实现输入tensor在最后一维上降序排列,并保留最大的num个元素,输出排序后的tensor及各元素对应的索引。
struct SortParam { SVector<int32_t> num; };
成员名称 |
类型 |
默认值 |
描述 |
---|---|---|---|
num |
SVector<int32_t> |
- |
排序后保留的最大的元素的数量。 num是一个仅含有一个值的SVector,该值需大于0且小于等于输入x最后一维的大小。 |
参数 |
维度 |
数据类型 |
格式 |
描述 |
---|---|---|---|---|
x |
[dim_0,dim_1,...,dim_n] |
float16/bf16 |
ND |
最后一维应至少有num个元素。 |
参数 |
维度 |
数据类型 |
格式 |
描述 |
---|---|---|---|---|
output |
[dim_0,dim_1,...,num] |
float16/bf16 |
ND |
最后一维排序后,最大的num个元素。 |
indices |
[dim_0,dim_1,...,num] |
int32 |
ND |
最大的num个元素对应的原索引。 |
num是一个仅含有一个值的SVector,该值需大于0且小于等于输入x最后一维的大小。
输入:
num = [1] x = [[3.0, 4.0], [5.0, 6.0]]
输出:
output = [[4.0], [6.0]] indices = [[1.0], [1.0]]