昇腾社区首页
中文
注册

模型评估启动脚本适配样例(argparse)

import argparse
import logging
import json
import stat
import os

DEFAULT_FLAGS = os.O_RDWR | os.O_CREAT
DEFAULT_MODES = stat.S_IWUSR | stat.S_IRUSR

# 模拟模型源码逻辑
def model_eval(args):	
    logging.info('%s', args.data_path)	# 模型代码中可接收到对应参数
    logging.info('%s', args.output_path)
    logging.info('%s', args.model_config)
    logging.info('%s', args.learning_rate)
    eval_result = json.dumps({'评估结果': '待输出评估结果'})
    eval_result_path = os.path.join(args.output_path, 'eval_result.json')	# 此处必须以eval_result.json文件存储至指定输出路径,否则评估任务无结果返回,视为失败
    with os.fdopen(os.open(eval_result_path, DEFAULT_FLAGS, DEFAULT_MODES), 'w') as file:
        json.dump(eval_result, file)

# 用于argparse参数类型转换。
def str2bool(param):
    return param.lower() == 'true'

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # 必选适配项
    parser.add_argument('-dp', '--data_path', type=str, required=True)	# 数据集路径
    parser.add_argument('-op', '--output_path', type=str, required=True)	# 输出路径
    parser.add_argument('-cp', '--ckpt_path', type=str, required=True)	# 模型结果路径

    # 模型配置文件中包含的配置项
    parser.add_argument('-ep', '--eval_param', type=float, required=False)
    parser.add_argument('-bp', '--bool_param', type=str2bool, required=False)  # argparse不支持传入bool类型参数,需要在type中转换参数类型

    # 模型脚本接收的参数集合
    args = parser.parse_args()
    model_eval(args)

对应模型配置文件的内容:

params:
 eval_param: 'your settings'
 bool_param: false