Training Scenarios

Rec SDK TensorFlow supports two application scenarios: tf.Session training and NPUEstimator training.

  • tf.Session training scenario: The model is launched through a newly created session instance, which returns tensor instances that you use for custom model training.
  • NPUEstimator training scenario: NPUEstimator encapsulates the control of the different phases of machine learning, so you can focus on the network structure instead of rewriting the training, evaluation, and prediction code for each new machine learning task.

The following constraints apply to both scenarios:

  • Rec SDK TensorFlow does not support Keras.
  • Rec SDK TensorFlow supports porting only those model training scripts that use native TensorFlow APIs. Third-party frameworks (such as tf_adapter, HugeCTR, and DeepRec) are not supported.
  • Rec SDK TensorFlow accepts input data only in tf.data.Dataset format.
  • When training loops with iterations_per_loop are enabled, the total number of training iterations must be an integer multiple of the value of iterations_per_loop.
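
The last constraint is simple arithmetic, and it can be checked before a training job is launched. The following is a minimal sketch in plain Python. The helper names check_total_steps and round_down_steps are hypothetical and are not part of Rec SDK TensorFlow.

```python
def check_total_steps(total_steps: int, iterations_per_loop: int) -> int:
    """Return the number of loops, or raise if the constraint is violated.

    Hypothetical helper: it only checks that total_steps is an integer
    multiple of iterations_per_loop.
    """
    if iterations_per_loop <= 0:
        raise ValueError("iterations_per_loop must be a positive integer")
    loops, remainder = divmod(total_steps, iterations_per_loop)
    if remainder != 0:
        raise ValueError(
            f"total_steps={total_steps} is not an integer multiple of "
            f"iterations_per_loop={iterations_per_loop}"
        )
    return loops


def round_down_steps(total_steps: int, iterations_per_loop: int) -> int:
    """Return the largest multiple of iterations_per_loop that is <= total_steps."""
    return total_steps - total_steps % iterations_per_loop


# 1000 steps with 10 iterations per loop gives 100 loops.
num_loops = check_total_steps(1000, 10)

# 1005 steps would violate the constraint; rounding down gives 1000.
adjusted_steps = round_down_steps(1005, 10)
```

Rounding down is only one possible policy. Depending on the job, rounding the step count up or changing iterations_per_loop may be more appropriate.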

Mapping Between TensorFlow and Rec SDK TensorFlow APIs

During model porting, determine whether to use sparse table-related APIs based on the actual model code and code context. If TensorFlow APIs related to sparse tables are used, change them to Rec SDK TensorFlow APIs. Table 1 describes the API mapping.

Table 1 API mapping

TensorFlow API                                   Rec SDK TensorFlow API   Function description
-----------------------------------------------  -----------------------  ----------------------
MutableHashTable, tf.Variable                    create_table             Create a sparse table.
tf.nn.embedding_lookup,                          sparse_lookup            Query a sparse table.
mutable_hash_table.lookup (mutable_hash_table
is an instance of MutableHashTable)

API example:

  • TensorFlow example:
    import tensorflow as tf
    from tensorflow.contrib.lookup import MutableHashTable
    # .......
    user_id = features["user_ids"]
    # Create the sparse table; keys that are not in the table map to default_value.
    user_emb_table = MutableHashTable(key_dtype=tf.int64, value_dtype=tf.float32, default_value=0.0)
    # Query the sparse table with the user IDs.
    user_emb = user_emb_table.lookup(user_id)
    
  • Rec SDK TensorFlow example:
    import tensorflow as tf
    from mx_rec.core.embedding import create_table, sparse_lookup
    # .......
    # Create the sparse table (replaces MutableHashTable or tf.Variable).
    user_emb_table = create_table(key_dtype=tf.int64, value_dtype=tf.float32, name="user_table",
                                  dim=tf.TensorShape([1]), emb_initializer=tf.compat.v1.truncated_normal_initializer(mean=10),
                                  device_vocabulary_size=800000, host_vocabulary_size=0)
    # Query the sparse table; feature_spec_list and batch_size are assumed to be defined in the elided code above.
    user_emb = sparse_lookup(user_emb_table, feature_spec_list, batch_size * 16, is_train=True,
                             name=user_emb_table.table_name + "_lookup", modify_graph=False)
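
To make the create/lookup pairing in Table 1 more concrete, the following toy model in plain Python shows the general idea of a sparse table: rows are keyed by integer IDs and exist only for IDs that have actually been seen. This is a conceptual sketch, not the Rec SDK TensorFlow implementation. The class ToySparseTable is hypothetical, and it assumes that a row is created with a random initializer the first time its ID is looked up, which the text above does not confirm for create_table and sparse_lookup.

```python
import random
from typing import Dict, List, Sequence


class ToySparseTable:
    """Toy stand-in for a sparse embedding table (hypothetical, for illustration only)."""

    def __init__(self, name: str, dim: int, seed: int = 0) -> None:
        self.name = name
        self.dim = dim
        self._rows: Dict[int, List[float]] = {}  # only IDs seen so far have a row
        self._rng = random.Random(seed)

    def lookup(self, ids: Sequence[int]) -> List[List[float]]:
        """Return one embedding row per ID, creating missing rows on first use."""
        result = []
        for key in ids:
            if key not in self._rows:
                # Assumption: new rows are drawn from a normal initializer.
                self._rows[key] = [self._rng.gauss(0.0, 1.0) for _ in range(self.dim)]
            result.append(self._rows[key])
        return result

    def __len__(self) -> int:
        """Number of distinct IDs stored in the table."""
        return len(self._rows)


# "Create a sparse table", then "query a sparse table" with a batch of IDs.
table = ToySparseTable(name="user_table", dim=4)
embeddings = table.lookup([7, 42, 7])
```

The repeated ID 7 returns the same row both times, and the table stores only two rows even though the ID space is far larger. That is the property that distinguishes a sparse table from a dense tf.Variable of fixed vocabulary size.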