BaseMiner类
__call__接口功能
需要用户实现的难例筛选算法接口。
接口引用路径
ockics.modules.estimator.BaseMiner
__call__接口格式
函数:__call__(self, infer_data=None, infer_feature=None, infer_result=None)
__call__接口输入参数
参数 |
类型 |
是否必选 |
说明 |
取值要求 |
|---|---|---|---|---|
infer_data |
str |
必选 |
待判定的样本图片路径。 |
- |
infer_feature |
list |
必选 |
待判定的样本经过模型推理后的特征值,以一个list[numpy.ndarray]存放的所有特征值。 |
- |
infer_result |
list |
必选 |
list[dict]形式的推理结果,每个dict的格式如下: {
'pred_class': 3,
'pred_box': [x1, y1, x2, y2],
'confidence_score': 0.12,
'pred_core': [0.0078, 0.9301, 0.0035, 0.0026, ....],
'width': 3000,
'height': 2834
}
|
- |
__call__接口返回值
bool类型,代表是否难例。
- True:本样本是难例
- False:本样本非难例
__call__接口使用样例
import random
from ockics.common import ClassFactory, ClassType
from ockics.entry import main_hardmining
from ockics.modules.estimator import BaseMiner
@ClassFactory.register(ClassType.HEM, 'all')
class HCDFilter(BaseMiner):
def __call__(self, infer_data=None, infer_feature=None, infer_result=None):
return True
@ClassFactory.register(ClassType.HEM, 'random')
class RandomFilter(BaseMiner):
def __call__(self, infer_data=None, infer_feature=None, infer_result=None):
result = random.randint(0, 9)
return True if result == 1 else False
@ClassFactory.register(ClassType.HEM, 'error')
class ErrorFilter(BaseMiner):
def __call__(self, infer_data=None, infer_feature=None, infer_result=None):
result = random.randint(0, 1)
if result == 0:
raise ValueError("inject hardmining algorithm error")
return True
if __name__ == "__main__":
main_hardmining()
父主题: register接口