BaseMiner类
__call__接口功能
需要用户实现的难例筛选算法接口。
接口引用路径
ockics.modules.estimator.BaseMiner
__call__接口格式
函数:__call__(self, infer_data=None, infer_feature=None, infer_result=None)
__call__接口输入参数
参数 |
类型 |
是否必选 |
说明 |
取值要求 |
---|---|---|---|---|
infer_data |
str |
必选 |
待判定的样本图片路径。 |
- |
infer_feature |
list |
必选 |
待判定的样本经过模型推理后的特征值,以一个list[numpy.ndarray]存放的所有特征值。 |
- |
infer_result |
list |
必选 |
list[dict]形式的推理结果,每个dict的格式如下: { 'pred_class': 3, 'pred_box': [x1, y1, x2, y2], 'confidence_score': 0.12, 'pred_core': [0.0078, 0.9301, 0.0035, 0.0026, ....], 'width': 3000, 'height': 2834 } |
- |
__call__接口返回值
bool类型,代表是否难例。
- True:本样本是难例
- False:本样本非难例
__call__接口使用样例
import random from ockics.common import ClassFactory, ClassType from ockics.entry import main_hardmining from ockics.modules.estimator import BaseMiner @ClassFactory.register(ClassType.HEM, 'all') class HCDFilter(BaseMiner): def __call__(self, infer_data=None, infer_feature=None, infer_result=None): return True @ClassFactory.register(ClassType.HEM, 'random') class RandomFilter(BaseMiner): def __call__(self, infer_data=None, infer_feature=None, infer_result=None): result = random.randint(0, 9) return True if result == 1 else False @ClassFactory.register(ClassType.HEM, 'error') class ErrorFilter(BaseMiner): def __call__(self, infer_data=None, infer_feature=None, infer_result=None): result = random.randint(0, 1) if result == 0: raise ValueError("inject hardmining algorithm error") return True if __name__ == "__main__": main_hardmining()
父主题: register接口