Training Scenarios
Rec SDK TensorFlow supports two application scenarios: tf.Session training and NPUEstimator training.
- tf.Session training scenario: The model is started in a new session instance, and a tensor instance is returned for custom model training.
- NPUEstimator training scenario: The estimator encapsulates the control of the different phases of machine learning, so you can focus on the network structure without rewriting the training, evaluation, and prediction code for each new machine learning task.
- Rec SDK TensorFlow does not support Keras.
- Rec SDK TensorFlow supports only model training script porting using native TensorFlow APIs. Third-party frameworks (such as tf_adapter, HugeCTR, and DeepRec) are not supported.
- Rec SDK TensorFlow supports only input data in tf.data.Dataset format.
- When training loops with iterations_per_loop are enabled, the total number of training iterations must be an integer multiple of the value of iterations_per_loop.
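
The iterations_per_loop constraint can be checked before a job is launched. The following is a minimal sketch in plain Python; the helper names `check_total_steps` and `round_up_total_steps` are hypothetical and are not part of Rec SDK TensorFlow. Only the name `iterations_per_loop` comes from the constraint above.

```python
def check_total_steps(total_steps: int, iterations_per_loop: int) -> bool:
    """Return True if total_steps is a positive integer multiple of iterations_per_loop."""
    if iterations_per_loop <= 0:
        raise ValueError("iterations_per_loop must be a positive integer")
    return total_steps > 0 and total_steps % iterations_per_loop == 0


def round_up_total_steps(total_steps: int, iterations_per_loop: int) -> int:
    """Round total_steps up to the nearest multiple of iterations_per_loop."""
    if iterations_per_loop <= 0:
        raise ValueError("iterations_per_loop must be a positive integer")
    # Ceiling division using integer arithmetic only.
    loops = -(-total_steps // iterations_per_loop)
    return loops * iterations_per_loop


if __name__ == "__main__":
    # Hypothetical values chosen for illustration only.
    print(check_total_steps(1000, 10))
    print(check_total_steps(1005, 10))
    print(round_up_total_steps(1005, 10))
```

Whether to round the step count up or down when the check fails depends on the training plan; the sketch rounds up so that no planned iteration is dropped.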
Mapping Between TensorFlow and Rec SDK TensorFlow APIs
During model porting, determine whether to use sparse table-related APIs based on the actual model code and code context. If TensorFlow APIs related to sparse tables are used, change them to Rec SDK TensorFlow APIs. Table 1 describes the API mapping.
| TensorFlow API | Rec SDK TensorFlow API | Function Description |
|---|---|---|
| | create_table | Create a sparse table. |
| | sparse_lookup | Query a sparse table. |
API example:
- TensorFlow example:

      import tensorflow as tf
      from tensorflow.contrib.lookup import MutableHashTable
      # .......
      user_id = features["user_ids"]
      user_emb_table = MutableHashTable(key_dtype=tf.int64,
                                        value_dtype=tf.float32,
                                        default_value=0.0)
      user_emb = user_emb_table.lookup(user_id)

- Rec SDK TensorFlow example:

      import tensorflow as tf
      from mx_rec.core.embedding import create_table, sparse_lookup
      # .......
      # Create the sparse table.
      user_emb_table = create_table(key_dtype=tf.int64,
                                    value_dtype=tf.float32,
                                    name="user_table",
                                    dim=tf.TensorShape([1]),
                                    emb_initializer=tf.compat.v1.truncated_normal_initializer(mean=10),
                                    device_vocabulary_size=800000,
                                    host_vocabulary_size=0)
      # Query the sparse table.
      user_emb = sparse_lookup(user_emb_table,
                               feature_spec_list,
                               batch_size * 16,
                               is_train=True,
                               name=user_emb_table.table_name + "_lookup",
                               modify_graph=False)
