BaseEstimator类
train接口功能
需要用户实现的训练接口。
接口引用路径
ockics.modules.estimator.BaseEstimator
train接口格式
函数:train(self, init_model, train_data, eval_data, train_config_yaml) -> (str, str)
train接口输入参数
参数 |
类型 |
是否必选 |
说明 |
取值要求 |
|---|---|---|---|---|
init_model |
str |
必选 |
传递给训练脚本的模型文件路径。 |
- |
train_data |
str |
必选 |
传递给训练脚本的训练数据集目录,目录下按要求存放着数据集,以COCO数据集格式为例,此目录下存在两个子目录:annotations字母下存放标注json文件;images子目录下存放所有图片jpg文件。 |
- |
eval_data |
str |
必选 |
传递给训练脚本的测评数据集目录,目录内数据格式与train_data相同。 |
- |
train_config_yaml |
str |
必选 |
传递给训练脚本的训练配置文件路径。 |
- |
train接口返回值
- train_model:训练脚本产出的模型路径。
- infer_model:训练脚本导出的推理模型路径,一般导出ONNX格式的模型。
train接口使用样例
from ockics.common import ClassFactory, ClassType
from ockics.entry import main_train
from ockics.modules.estimator import BaseEstimator
@ClassFactory.register(ClassType.ESTIMATOR)
class Estimator(BaseEstimator):
'''
my own train estimator
'''
def train(self, init_model, train_data, eval_data, train_config_yaml):
print(init_model)
print(train_data)
print(eval_data)
print(train_config_yaml)
return "./samples/yolov5/yolov5x6_base.pt", "./samples/yolov5/yolov5x6.onnx"
if __name__ == "__main__":
eval接口功能
需要用户实现的模型测评接口。
eval接口格式
函数:eval(self, model, eval_data, eval_config_yaml) -> (float, float, float)
eval接口输入参数
参数 |
类型 |
是否必选 |
说明 |
取值要求 |
|---|---|---|---|---|
model |
str |
- |
待测评的模型路径名称,此处会将训练模型传递给脚本进行测评。 |
- |
eval_data |
str |
- |
传递给测评脚本的测评数据集目录,目录中按要求存放着数据集,以COCO数据集格式为例,此目录下存在两个子目录:annotations字母下存放标注json文件;images子目录下存放所有图片jpg文件。 |
- |
eval_config_yaml |
str |
- |
传递给测评脚本的测评配置文件路径。 |
- |
eval接口返回值
- 全类平均精度(mAP):建议返回mAP@.5值,用户也可以返回其他指标。
- 精准率(precision):测评精度。
- 召回率(recall):测评召回率。
eval接口使用样例
from ockics.common.class_factory import ClassFactory, ClassType
from ockics.entry.evaluate import main_eval
from ockics.modules.estimator.estimator import BaseEstimator
@ClassFactory.register(ClassType.ESTIMATOR)
class Estimator(BaseEstimator):
'''
my own train estimator
'''
def eval(self, model, eval_data, eval_config_yaml) -> float:
'''
self defined eval implementation
'''
super().eval(model, eval_data, eval_config_yaml)
print(model)
print(eval_data)
print(eval_config_yaml)
return 0.565, 0.8, 0.9
if __name__ == "__main__":
main_eval()
父主题: register接口