昇腾社区首页
中文
注册

BaseEstimator类

train接口功能

需要用户实现的训练接口。

接口引用路径

ockics.modules.estimator.BaseEstimator

train接口格式

函数:train(self, init_model, train_data, eval_data, train_config_yaml) -> (str, str)

train接口输入参数

参数

类型

是否必选

说明

取值要求

init_model

str

必选

传递给训练脚本的模型文件路径。

-

train_data

str

必选

传递给训练脚本的训练数据集目录,目录下按要求存放着数据集,以COCO数据集格式为例,此目录下存在两个子目录:annotations字母下存放标注json文件;images子目录下存放所有图片jpg文件。

-

eval_data

str

必选

传递给训练脚本的测评数据集目录,目录内数据格式与train_data相同。

-

train_config_yaml

str

必选

传递给训练脚本的训练配置文件路径。

-

train接口返回值

  • train_model:训练脚本产出的模型路径。
  • infer_model:训练脚本导出的推理模型路径,一般导出ONNX格式的模型。

train接口使用样例

from ockics.common import ClassFactory, ClassType
from ockics.entry import main_train
from ockics.modules.estimator import BaseEstimator
@ClassFactory.register(ClassType.ESTIMATOR)
class Estimator(BaseEstimator):
    '''
    my own train estimator
    '''
    def train(self, init_model, train_data, eval_data, train_config_yaml):
        print(init_model)
        print(train_data)
        print(eval_data)
        print(train_config_yaml)
        return "./samples/yolov5/yolov5x6_base.pt", "./samples/yolov5/yolov5x6.onnx"
if __name__ == "__main__":
    

eval接口功能

需要用户实现的模型测评接口。

eval接口格式

函数:eval(self, model, eval_data, eval_config_yaml) -> (float, float, float)

eval接口输入参数

参数

类型

是否必选

说明

取值要求

model

str

-

待测评的模型路径名称,此处会将训练模型传递给脚本进行测评。

-

eval_data

str

-

传递给测评脚本的测评数据集目录,目录中按要求存放着数据集,以COCO数据集格式为例,此目录下存在两个子目录:annotations字母下存放标注json文件;images子目录下存放所有图片jpg文件。

-

eval_config_yaml

str

-

传递给测评脚本的测评配置文件路径。

-

eval接口返回值

  • 全类平均精度(mAP):建议返回mAP@.5值,用户也可以返回其他指标。
  • 精准率(precision):测评精度。
  • 召回率(recall):测评召回率。

eval接口使用样例

from ockics.common.class_factory import ClassFactory, ClassType
from ockics.entry.evaluate import main_eval
from ockics.modules.estimator.estimator import BaseEstimator
@ClassFactory.register(ClassType.ESTIMATOR)
class Estimator(BaseEstimator):
    '''
    my own train estimator
    '''
    def eval(self, model, eval_data, eval_config_yaml) -> float:
        '''
        self defined eval implementation
        '''
        super().eval(model, eval_data, eval_config_yaml)
        print(model)
        print(eval_data)
        print(eval_config_yaml)
        return 0.565, 0.8, 0.9
if __name__ == "__main__":
    main_eval()