昇腾社区首页
中文
注册

训练场景介绍

训练场景介绍

Rec SDK TensorFlow提供使用tf.Session训练场景和NPUEstimator训练场景两种使用场景。

  • tf.Session训练场景。通过新建的Session实例启动模型运行,返回Tensor示例,进行定制化模型训练。
  • NPUEstimator训练场景。基于对机器学习不同阶段的控制的封装,用户无需不断地为新机器学习任务重复编写训练、评估、预测的代码,可以专注于对网络结构的控制。
  • Rec SDK TensorFlow暂时不支持Keras。
  • 目前Rec SDK TensorFlow仅支持使用TensorFlow原生API模型训练脚本迁移,不支持使用第三方框架(tf_adapter、HugeCTR、DeepRec等)。
  • Rec SDK TensorFlow目前仅支持模型的输入数据为tf.data.Dataset格式。
  • 启用大小循环的情况下,训练迭代的总次数必须是小循环(即iterations_per_loop)的整数倍。

TensorFlowRec SDK TensorFlow接口对应关系

在进行模型迁移时,需要根据实际的模型代码以及代码上下文判断是否使用到稀疏表相关的接口,如果是与稀疏表相关的TensorFlow接口,需要修改为Rec SDK TensorFlow的接口,接口对应关系如表1所示。

表1 接口对应关系

TensorFlow接口

Rec SDK TensorFlow接口

接口功能描述

  • MutableHashTable
  • tf.Variable

create_table

创建稀疏表

  • tf.embedding_lookup
  • mutable_hash_table.lookup(mutable_hash_table是MutableHashTable的实例)等

sparse_lookup

查询稀疏表

接口示例:

  • TensorFlow示例:
    1
    2
    3
    4
    5
    6
    import tensorflow as tf
    from tensorflow.contrib.lookup import MutableHashTable
    # .......
    user_id = features["user_ids"]
    user_emb_table = MutableHashTable(key_dtype=tf.int64, value_dtype=tf.float32, default_value=0.0)
    user_emb = user_emb_table.lookup(user_id)
    
  • Rec SDK TensorFlow示例:
    1
    2
    3
    4
    5
    import tensorflow as tf
    from mx_rec.core.embedding import create_table, sparse_lookup
    # .......
    user_emb_table = create_table(key_dtype=tf.int64, value_dtype=tf.float32, name="user_table", dim=tf.Tensorshape([1]),             emb_initializer=tf.compat.v1.truncated_normal_initializer(mean=10), device_vocabulary_size= 800000, host_vocabulary_size=0)
    user_emb = sparse_lookup(user_emb_table, feature_spec_list, batch_size*16, is_train=True, name=user_emb_table.table_name + "_lookup", modify_graph=False)