昇腾社区首页
中文
注册

beta)torch_npu.contrib.function.npu_batched_multiclass_nms

接口原型

npu_batched_multiclass_nms(multi_bboxes,multi_scores,score_thr=0.05,nms_thr=0.45,max_num=50,score_factors=None):

功能描述

使用NPU API的批量多类bbox NMS。

参数说明

  • multi_bboxes (Tensor) - shape(bs, n, #class, 4)或(bs, n, 4)。
  • multi_scores (Tensor) - shape(bs, n, #class+1),其中最后一列包含background class分数,可忽略。在NPU上,为保持语义畅通,我们将统一维度。
  • score_thr (Float,默认值为0.05) - bbox阈值,分数低于它的bbox将不被考虑。
  • nms_thr (Float,默认值为0.45) - NMS IoU阈值。最初的实现是传递{"iouthreshold": 0.45}字典,这里做了简化。
  • max_num (Int,默认值为50) - 如果NMS后的bbox数超过max_num值,则只保留最大max_num;如果NMS后的bbox数小于max_num值,则输出将零填充到max_num值。在NPU上需提前申请内存,因此目前不能将max_num值设置为-1。
  • score_factors (Tensor,默认值为None) - NMS应用前用来乘分数的因子。

输出说明

Tuple - (bboxes, labels),shape为(bs, k, 5)和(bs, k, 1)的张量。标签以0为基础。

约束说明

在动态shape条件下,由于NPU op的限制,最多支持20个类别(nmsed_classes)和10000个框(nmsed_boxes)。

调用示例

调用方式示例:
from torch_npu.contrib.function import npu_batched_multiclass_nms
det_bboxes, det_labels = npu_batched_multiclass_nms(boxes, scores, score_thr=0.3, nms_thr=0.5, max_num=3)
使用示例:
  >>> boxes = torch.randint(1, 255, size=(4, 200, 80, 4))
  >>> scores = torch.randn(4, 200, 81)
  >>> boxes = boxes.npu().half()
  >>> scores = scores.npu().half()
  >>> det_bboxes, det_labels = npu_batched_multiclass_nms(boxes, scores, score_thr=0.3, nms_thr=0.5, max_num=3)
  >>> expedt_det_bboxes = torch.tensor([[[221.8750,  60.0000, 183.0000,  22.0000,   3.8867], [167.0000, 250.0000, 136.0000, 144.0000,   3.6445], [ 45.9688, 147.0000,  67.0000, 241.8750,   3.4844]], 
        [[  5.0000, 178.0000, 243.8750, 138.0000,   3.7344], [238.0000, 132.0000,  47.0000,  84.0000,   3.6836], [ 32.0000, 110.0000, 131.0000,  73.0000,   3.6309]], 
        [[111.9375, 120.9375,  54.0000, 231.0000,   3.9219], [147.0000, 162.0000,  78.0000,   1.0010,   3.9219], [157.0000, 118.0000,  57.0000, 115.0000,   3.6523]], 
        [[ 80.0000, 126.9375,  54.0000, 246.8750,   3.7344], [ 31.0000, 253.8750,  19.0000, 138.0000,   3.6328], [ 54.0000, 253.8750,  78.0000,  75.0000,   3.5586]]],dtype=torch.float16)