昇腾社区首页
中文
注册

main_train_process

接口功能

trainer函数入口。除了用于存储中间状态的工作目录参数是s3路径外,其他的输入都是从本地文件系统输入。

接口引用路径

ockics.entry.main_train_process

接口格式

函数:main_train_process(train_para)

输入参数

参数

类型

是否必选

说明

取值要求

train_para

ParaTrain

必选

trainer组件的所有输入数据。

函数参数,元组类型,具体元素详见ParaTrain

返回值

使用样例

import argparse
from collections import namedtuple
ParaTrain = namedtuple("ParaTrain",
                       ["base_model", "train_datasets", "train_dataset_images", "eval_datasets", "eval_dataset_images",
                        "train_config", "dataset_fmt", "standalone", "work_url", "output_train_artifact",
                        "output_infer_artifact", "s3_endpoint", "s3_ak", "s3_sk", "s3_secure", "s3_certcheck"])

def main_train(s3_endpoint=None, s3_ak=None, s3_sk=None, s3_secure=None, s3_certcheck=None):
    argss = _parse_args()
    para = ParaTrain(argss.baseModel, argss.trainDatasets, argss.train_dataset_images, argss.evalDatasets,
                     argss.eval_dataset_images, argss.trainConfig,
                     argss.dataset_fmt, argss.standalone, argss.workurl,
                     argss.outputTrainArtifact, argss.outputInferArtifact,
                     s3_endpoint, s3_ak, s3_sk, s3_secure, s3_certcheck)
    main_train_process(para)