BaseEstimator类
train接口功能
需要用户实现的训练接口。
接口引用路径
ockics.modules.estimator.BaseEstimator
train接口格式
函数:train(self, init_model, train_data, eval_data, train_config_yaml) -> (str, str)
train接口输入参数
参数 |
类型 |
是否必选 |
说明 |
取值要求 |
---|---|---|---|---|
init_model |
str |
必选 |
传递给训练脚本的模型文件路径。 |
- |
train_data |
str |
必选 |
传递给训练脚本的训练数据集目录,目录下按要求存放着数据集,以COCO数据集格式为例,此目录下存在两个子目录:annotations字母下存放标注json文件;images子目录下存放所有图片jpg文件。 |
- |
eval_data |
str |
必选 |
传递给训练脚本的测评数据集目录,目录内数据格式与train_data相同。 |
- |
train_config_yaml |
str |
必选 |
传递给训练脚本的训练配置文件路径。 |
- |
train接口返回值
- train_model:训练脚本产出的模型路径。
- infer_model:训练脚本导出的推理模型路径,一般导出ONNX格式的模型。
train接口使用样例
from ockics.common import ClassFactory, ClassType from ockics.entry import main_train from ockics.modules.estimator import BaseEstimator @ClassFactory.register(ClassType.ESTIMATOR) class Estimator(BaseEstimator): ''' my own train estimator ''' def train(self, init_model, train_data, eval_data, train_config_yaml): print(init_model) print(train_data) print(eval_data) print(train_config_yaml) return "./samples/yolov5/yolov5x6_base.pt", "./samples/yolov5/yolov5x6.onnx" if __name__ == "__main__":
eval接口功能
需要用户实现的模型测评接口。
eval接口格式
函数:eval(self, model, eval_data, eval_config_yaml) -> (float, float, float)
eval接口输入参数
参数 |
类型 |
是否必选 |
说明 |
取值要求 |
---|---|---|---|---|
model |
str |
- |
待测评的模型路径名称,此处会将训练模型传递给脚本进行测评。 |
- |
eval_data |
str |
- |
传递给测评脚本的测评数据集目录,目录中按要求存放着数据集,以COCO数据集格式为例,此目录下存在两个子目录:annotations字母下存放标注json文件;images子目录下存放所有图片jpg文件。 |
- |
eval_config_yaml |
str |
- |
传递给测评脚本的测评配置文件路径。 |
- |
eval接口返回值
- 全类平均精度(mAP):建议返回mAP@.5值,用户也可以返回其他指标。
- 精准率(precision):测评精度。
- 召回率(recall):测评召回率。
eval接口使用样例
from ockics.common.class_factory import ClassFactory, ClassType from ockics.entry.evaluate import main_eval from ockics.modules.estimator.estimator import BaseEstimator @ClassFactory.register(ClassType.ESTIMATOR) class Estimator(BaseEstimator): ''' my own train estimator ''' def eval(self, model, eval_data, eval_config_yaml) -> float: ''' self defined eval implementation ''' super().eval(model, eval_data, eval_config_yaml) print(model) print(eval_data) print(eval_config_yaml) return 0.565, 0.8, 0.9 if __name__ == "__main__": main_eval()
父主题: register接口