def __init__(self,
model_fn=None,
model_dir=None,
config=None,
params=None,
job_start_file='',
warm_start_from=None
)
NPUEstimator类的构造函数,NPUEstimator类继承了Estimator类,可以调用基类的原生接口。
参数名 |
输入/输出 |
描述 |
---|---|---|
model_fn |
输入 |
模型function定义,该function返回NPUEstimatorSpec类对象。 关于NPUEstimatorSpec类的构造函数,请参见NPUEstimatorSpec构造函数。 |
model_dir |
输入 |
保存模型路径, 用于保存或恢复模型文件。默认为None。 如果NPURunConfig和NPUEstimator配置的model_dir不同,系统报错。 如果NPURunConfig和NPUEstimator仅一个接口配置model_dir,以配置的路径为准。 如果NPURunConfig和NPUEstimator均未配置model_dir,则系统在当前脚本执行路径创建一个model_dir_xxxxxxxxxx目录保存模型文件。 |
config |
输入 |
NPURunConfig类对象。 关于NPURunConfig类的构造函数,请参见NPURunConfig构造函数。 |
params |
输入 |
传入model_fn的参数,为字典类型,键为传入参数的名字,值为基本的python类型值 |
job_start_file |
输入 |
CSA job 启动文件路径,上云场景无需配置。 |
warm_start_from |
输入 |
指定checkpoint路径,会导入该checkpoint开始训练。 |
返回NPUEstimator类对象。