昇腾社区首页
中文
注册

BaseMiner类

__call__接口功能

需要用户实现的难例筛选算法接口。

接口引用路径

ockics.modules.estimator.BaseMiner

__call__接口格式

函数:__call__(self, infer_data=None, infer_feature=None, infer_result=None)

__call__接口输入参数

参数

类型

是否必选

说明

取值要求

infer_data

str

必选

待判定的样本图片路径。

-

infer_feature

list

必选

待判定的样本经过模型推理后的特征值,以一个list[numpy.ndarray]存放的所有特征值。

-

infer_result

list

必选

list[dict]形式的推理结果,每个dict的格式如下:

{
                  'pred_class': 3,
                  'pred_box': [x1, y1, x2, y2],
                  'confidence_score': 0.12,
                  'pred_core': [0.0078, 0.9301, 0.0035, 0.0026, ....],
                  'width': 3000,
                  'height': 2834
}

-

__call__接口返回值

bool类型,代表是否难例。

  • True:本样本是难例
  • False:本样本非难例

__call__接口使用样例

import random
from ockics.common import ClassFactory, ClassType
from ockics.entry import main_hardmining
from ockics.modules.estimator import BaseMiner
@ClassFactory.register(ClassType.HEM, 'all')
class HCDFilter(BaseMiner):
    def __call__(self, infer_data=None, infer_feature=None, infer_result=None):
        
        return True
@ClassFactory.register(ClassType.HEM, 'random')
class RandomFilter(BaseMiner):
    def __call__(self, infer_data=None, infer_feature=None, infer_result=None):
        result = random.randint(0, 9)
        
        return True if result == 1 else False
@ClassFactory.register(ClassType.HEM, 'error')
class ErrorFilter(BaseMiner):
    def __call__(self, infer_data=None, infer_feature=None, infer_result=None):
        result = random.randint(0, 1)
        if result == 0:
            raise ValueError("inject hardmining algorithm error")
        
        return True
if __name__ == "__main__":
    main_hardmining()