工厂类型类
ClassType接口功能
工厂类型类,定义了ICS-SDK中需要用到的几种类型。
接口引用路径
ockics.common.ClassType
ClassType接口格式
类成员变量:
class ClassType:
"""Const class saved defined class type."""
GENERAL = 'general'
HEM = 'hard_example_mining'
ESTIMATOR = 'estimator'
MODEL_RELEASE = 'model_release'
CALLBACK = 'post_process_callback'
ClassType接口输入参数
参数 |
类型 |
是否必选 |
说明 |
取值要求 |
|---|---|---|---|---|
GENERAL |
str |
- |
通用类别,暂未使用。 |
- |
HEM |
str |
- |
难例筛选算法,注册难例筛选算法接口。 |
- |
ESTIMATOR |
str |
- |
训练、测评脚本对象,注册模型训练和测评脚本接口。 |
- |
MODEL_RELEASE |
str |
- |
模型发布对象,注册模型发布的接口。 |
- |
使用样例
from ockics.common import ClassFactory, ClassType
from ockics.entry import main_train
from ockics.modules.estimator import BaseEstimator
@ClassFactory.register(ClassType.ESTIMATOR)
class Estimator(BaseEstimator):
'''
my own train estimator
'''
def train(self, init_model, train_data, eval_data, train_config_yaml):
print(init_model)
print(train_data)
print(eval_data)
print(train_config_yaml)
return "./samples/yolov5/yolov5x6_base.pt", "./samples/yolov5/yolov5x6.onnx"
if __name__ == "__main__":
main_train()
父主题: register接口