NPUEstimator Constructor
Description
Constructs an NPUEstimator object. NPUEstimator inherits from the TensorFlow Estimator class, so the native APIs of the base class can be called on it to train and evaluate TensorFlow models.
Prototype
def __init__(self, model_fn=None, model_dir=None, config=None, params=None, job_start_file='', warm_start_from=None)
Options
| Option | Input/Output | Description |
|---|---|---|
| model_fn | Input | Model function. It must return an object of the NPUEstimatorSpec class. For details about the constructor of the NPUEstimatorSpec class, see NPUEstimatorSpec Constructor. |
| model_dir | Input | Directory in which model files are saved and from which they are restored. Defaults to None. If model_dir is set in both NPURunConfig and NPUEstimator and the two values differ, an error is reported. If only one of them sets model_dir, or both set the same path, that path is used. If neither sets it, a model_dir_xxxxxxxxxx directory is created in the current script execution path to save the model files. |
| config | Input | Object of the NPURunConfig class. For details about the constructor of the NPURunConfig class, see NPURunConfig Constructor. |
| params | Input | Dictionary of arguments passed to model_fn. Each key is an argument name, and each value is a basic Python type. |
| job_start_file | Input | Path of the startup file of the CSA job. |
| warm_start_from | Input | Path of the checkpoint to import for training. |
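
The model_dir rules above can be read as a small decision procedure. The following is a minimal sketch of that reading in plain Python, not the library's actual implementation. The names `resolve_model_dir` and `suffix` are hypothetical, the use of `ValueError` is an assumption (the source only says "an error is reported"), and the source does not say how the `xxxxxxxxxx` part of the default directory name is generated, so the sketch leaves it to the caller and only computes the path without creating it.

```python
import os


def resolve_model_dir(estimator_dir=None, config_dir=None, suffix="xxxxxxxxxx"):
    """Hypothetical helper mirroring the documented model_dir rules.

    estimator_dir: model_dir passed to NPUEstimator (or None).
    config_dir:    model_dir passed to NPURunConfig (or None).
    suffix:        placeholder for the generated part of the default name;
                   how the real suffix is produced is not documented here.
    """
    # Both set but different: the documentation says an error is reported.
    # ValueError is an assumption made for this sketch.
    if estimator_dir and config_dir and estimator_dir != config_dir:
        raise ValueError(
            "model_dir differs between NPURunConfig (%r) and NPUEstimator (%r)"
            % (config_dir, estimator_dir)
        )
    # Only one set, or both set to the same path: that path applies.
    if estimator_dir or config_dir:
        return estimator_dir or config_dir
    # Neither set: model_dir_<suffix> under the current execution path.
    return os.path.join(os.getcwd(), "model_dir_" + suffix)
```

Setting model_dir in only one place avoids the mismatch case entirely.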
Returns
An object of the NPUEstimator class.
Examples
    from npu_bridge.npu_init import *
    ...
    self._classifier = NPUEstimator(
        model_fn=cnn_model_fn,
        model_dir=self._model_dir,
        # NPURunConfig comes from npu_bridge.npu_init, not from tf.estimator.
        config=NPURunConfig(
            # Rank 0 uses a 50-step checkpoint interval; every other rank passes 0.
            save_checkpoints_steps=50 if get_rank_id() == 0 else 0,
            keep_checkpoint_max=1))
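
The conditional passed to save_checkpoints_steps in the example gives rank 0 a 50-step interval and every other rank the value 0, presumably so that only one device writes checkpoints (the source does not state how 0 is interpreted). A minimal sketch of that expression in isolation, assuming a hypothetical helper `checkpoint_steps_for_rank` that takes the rank ID as an argument instead of calling get_rank_id():

```python
def checkpoint_steps_for_rank(rank_id, interval=50):
    """Hypothetical helper: the same conditional as in the example."""
    # Rank 0 gets the real interval; all other ranks get 0.
    return interval if rank_id == 0 else 0


# Values that each rank of a hypothetical 4-device job would pass
# to save_checkpoints_steps.
steps_per_rank = [checkpoint_steps_for_rank(rank) for rank in range(4)]
print(steps_per_rank)  # [50, 0, 0, 0]
```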