NPUEstimator构造函数

函数原型

def __init__(self,

model_fn=None,

model_dir=None,

config=None,

params=None,

job_start_file='',

warm_start_from=None

)

功能说明

NPUEstimator类的构造函数,NPUEstimator类继承了Estimator类,可以调用基类的原生接口。

参数说明

参数名

输入/输出

描述

model_fn

输入

模型function定义,该function返回NPUEstimatorSpec类对象。

关于NPUEstimatorSpec类的构造函数,请参见NPUEstimatorSpec构造函数

model_dir

输入

保存模型路径, 用于保存或恢复模型文件。默认为None。

如果NPURunConfig和NPUEstimator配置的model_dir不同,系统报错。

如果NPURunConfig和NPUEstimator仅一个接口配置model_dir,以配置的路径为准。

如果NPURunConfig和NPUEstimator均未配置model_dir,则系统在当前脚本执行路径创建一个model_dir_xxxxxxxxxx目录保存模型文件。

config

输入

NPURunConfig类对象。

关于NPURunConfig类的构造函数,请参见NPURunConfig构造函数

params

输入

传入model_fn的参数,为字典类型,键为传入参数的名字,值为基本的python类型值

job_start_file

输入

CSA job 启动文件路径,上云场景无需配置。

warm_start_from

输入

指定checkpoint路径,会导入该checkpoint开始训练。

返回值

返回NPUEstimator类对象。