昇腾社区首页
中文
注册

工厂类型类

ClassType接口功能

工厂类型类,定义了ICS-SDK中需要用到的几种类型。

接口引用路径

ockics.common.ClassType

ClassType接口格式

类成员变量:

class ClassType:
    """Const class saved defined class type."""
    GENERAL = 'general'
    HEM = 'hard_example_mining'
    ESTIMATOR = 'estimator'
    MODEL_RELEASE = 'model_release'
    CALLBACK = 'post_process_callback'

ClassType接口输入参数

参数

类型

是否必选

说明

取值要求

GENERAL

str

-

通用类别,暂未使用。

-

HEM

str

-

难例筛选算法,注册难例筛选算法接口。

-

ESTIMATOR

str

-

训练、测评脚本对象,注册模型训练和测评脚本接口。

-

MODEL_RELEASE

str

-

模型发布对象,注册模型发布的接口。

-

使用样例

from ockics.common import ClassFactory, ClassType
from ockics.entry import main_train
from ockics.modules.estimator import BaseEstimator
@ClassFactory.register(ClassType.ESTIMATOR)
class Estimator(BaseEstimator):
    '''
    my own train estimator
    '''
    def train(self, init_model, train_data, eval_data, train_config_yaml):
        print(init_model)
        print(train_data)
        print(eval_data)
        print(train_config_yaml)
        return "./samples/yolov5/yolov5x6_base.pt", "./samples/yolov5/yolov5x6.onnx"
if __name__ == "__main__":
    main_train()