- Initialize the framework. Call the init interface in main.py and pass in the parameters required to initialize the framework. For the related parameters, see init.
# The nbatch function must be used together with prefetch and with host_vocabulary_size != 0.
init(max_steps=max_steps,
     train_steps=TRAIN_STEPS,
     eval_steps=EVAL_STEPS,
     save_steps=SAVE_STEPS,
     use_dynamic=use_dynamic,
     use_dynamic_expansion=use_dynamic_expansion)
- Define the dataset. Call the get_asc_insert_func interface in main.py to create the dataset and preprocess it. For the related parameters, see the parameter description.
if not MODIFY_GRAPH_FLAG:
    insert_fn = get_asc_insert_func(tgt_key_specs=feature_spec_list, is_training=is_training, dump_graph=dump_graph)
    dataset = dataset.map(insert_fn)
dataset = dataset.prefetch(100)
iterator = dataset.make_initializable_iterator()
batch = iterator.get_next()
return batch, iterator
- Define the optimizer. Define the optimizer in optimizer.py. For the supported optimizer types and related parameters, see Optimizers.
# coding: UTF-8
import logging

import tensorflow as tf

from mx_rec.optimizers.lazy_adam import create_hash_optimizer
from mx_rec.optimizers.lazy_adam_by_addr import create_hash_optimizer_by_address
from mx_rec.util.initialize import get_use_dynamic_expansion


def get_dense_and_sparse_optimizer(cfg):
    dense_optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=cfg.learning_rate)
    use_dynamic_expansion = get_use_dynamic_expansion()
    if use_dynamic_expansion:
        sparse_optimizer = create_hash_optimizer_by_address(learning_rate=cfg.learning_rate)
        logging.info("optimizer lazy_adam_by_addr")
    else:
        sparse_optimizer = create_hash_optimizer(learning_rate=cfg.learning_rate)
        logging.info("optimizer lazy_adam")
    return dense_optimizer, sparse_optimizer
- Create the sparse tables. Call the create_table interface in main.py to create the sparse tables and build the sparse network layer. For the related parameters, see the parameter description.
user_hashtable = create_table(key_dtype=tf.int64,
                              dim=tf.TensorShape([cfg.user_hashtable_dim]),
                              name='user_table',
                              emb_initializer=tf.compat.v1.truncated_normal_initializer(),
                              device_vocabulary_size=cfg.user_vocab_size * 10,
                              host_vocabulary_size=0)  # set to cfg.user_vocab_size * 100 for the h2d test
item_hashtable = create_table(key_dtype=tf.int64,
                              dim=tf.TensorShape([cfg.item_hashtable_dim]),
                              name='item_table',
                              emb_initializer=tf.compat.v1.truncated_normal_initializer(),
                              device_vocabulary_size=cfg.item_vocab_size * 10,
                              host_vocabulary_size=0)  # set to cfg.item_vocab_size * 100 for the h2d test
- Build the computation graph. Pass in the sparse network layers and the feature list to build the model computation graph; inside the graph, call sparse_lookup to look up features and compute the loss. For the related parameters, see the parameter description.
def model_forward(input_list, batch, is_train, modify_graph, config_dict=None):
    embedding_list = []
    feature_list, hash_table_list, send_count_list = input_list
    for feature, hash_table, send_count in zip(feature_list, hash_table_list, send_count_list):
        access_and_evict_config = None
        if isinstance(config_dict, dict):
            access_and_evict_config = config_dict.get(hash_table.table_name)
        embedding = sparse_lookup(hash_table, feature, send_count, is_train=is_train,
                                  access_and_evict_config=access_and_evict_config,
                                  name=hash_table.table_name + "_lookup", modify_graph=modify_graph, batch=batch)
        # Sum-pool the looked-up embeddings before feeding them to the dense model.
        reduced_embedding = tf.reduce_sum(embedding, axis=1, keepdims=False)
        embedding_list.append(reduced_embedding)
    my_model = MyModel()
    my_model(embedding_list, batch["label_0"], batch["label_1"])
    return my_model
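The build_graph helper used in the next step is not shown in this example. The following is only a minimal sketch of what it might look like; the make_dataset wrapper, the feature keys and the send-count fields are assumptions for illustration, not part of the mxRec API.
def build_graph(hash_table_list, is_train, feature_spec_list, config_dict=None, batch_number=None):
    # make_dataset is a hypothetical wrapper around the dataset snippet shown earlier;
    # batch_number is accepted to match the call site, its use is omitted in this sketch.
    batch, iterator = make_dataset(feature_spec_list, is_training=is_train)
    # Hypothetical feature keys and send counts; adapt them to the actual feature specs.
    feature_list = [batch["user_id"], batch["item_id"]]
    send_count_list = [cfg.user_send_count, cfg.item_send_count]
    model = model_forward([feature_list, hash_table_list, send_count_list], batch,
                          is_train=is_train, modify_graph=MODIFY_GRAPH_FLAG,
                          config_dict=config_dict)
    return iterator, model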
- Define the gradient computation and optimization. Call the get_dense_and_sparse_variable interface in main.py to obtain the parameters of the dense and sparse network layers, then compute gradients with the optimizers and apply the updates. For the interface, see get_dense_and_sparse_variable.
train_iterator, train_model = build_graph([user_hashtable, item_hashtable], is_train=True,
                                          feature_spec_list=train_feature_spec_list,
                                          config_dict=ACCESS_AND_EVICT, batch_number=cfg.batch_number)
eval_iterator, eval_model = build_graph([user_hashtable, item_hashtable], is_train=False,
                                        feature_spec_list=eval_feature_spec_list,
                                        config_dict=ACCESS_AND_EVICT, batch_number=cfg.batch_number)
dense_variables, sparse_variables = get_dense_and_sparse_variable()
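The train_ops run in the Session below are not defined in the snippets above. A minimal sketch of how they could be assembled from the two optimizers and the variables returned above, assuming the training loss is the sum of train_model.loss_list:
dense_optimizer, sparse_optimizer = get_dense_and_sparse_optimizer(cfg)
loss = tf.add_n(train_model.loss_list)
# Compute and apply gradients separately for the dense and the sparse variables.
dense_grads_and_vars = dense_optimizer.compute_gradients(loss, var_list=dense_variables)
sparse_grads_and_vars = sparse_optimizer.compute_gradients(loss, var_list=sparse_variables)
train_ops = tf.group(dense_optimizer.apply_gradients(dense_grads_and_vars),
                     sparse_optimizer.apply_gradients(sparse_grads_and_vars))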
- Start data loading and preprocessing. Call the modify_graph_and_start_emb_cache interface (graph modification mode) or the start_asc_pipeline interface (non graph modification mode) in main.py to start the data pipeline (the sample code uses an if statement on MODIFY_GRAPH_FLAG from the configuration file to decide whether graph modification mode is used). For the interface, see modify_graph_and_start_emb_cache.
if MODIFY_GRAPH_FLAG:
    logging.info("start modifying graph")
    modify_graph_and_start_emb_cache(dump_graph=True)
else:
    start_asc_pipeline()
- Start the Session computation and save the model during training. Call the saver interface in main.py to start the Session computation and save the model during training.
saver = tf.compat.v1.train.Saver()
with tf.compat.v1.Session(config=sess_config(dump_data=False)) as sess:
    if MODIFY_GRAPH_FLAG:
        sess.run(get_initializer(True))
    else:
        sess.run(train_iterator.initializer)
    sess.run(tf.compat.v1.global_variables_initializer())
    EPOCH = 0
    # Restore from an existing checkpoint if one was saved before; otherwise save the initial state.
    if os.path.exists(f"./saved-model/sparse-model-{rank_id}-0"):
        saver.restore(sess, f"./saved-model/model-{rank_id}-0")
    else:
        saver.save(sess, f"./saved-model/model-{rank_id}", global_step=0)
    for i in range(1, 201):
        logging.info(f"################ training at step {i} ################")
        try:
            sess.run([train_ops, train_model.loss_list])
        except tf.errors.OutOfRangeError:
            logging.info("Encountered the end of sequence for training.")
            break
        else:
            if i % TRAIN_INTERVAL == 0:
                EPOCH += 1
                evaluate()
            if i % SAVING_INTERVAL == 0:
                saver.save(sess, f"./saved-model/model-{rank_id}", global_step=i)
    saver.save(sess, f"./saved-model/model-{rank_id}", global_step=i)
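The evaluate helper called in the training loop is not shown in this example. A rough sketch of what it might do, assuming it reuses the Session opened above (retrieved as the default session) and simply iterates over the evaluation pipeline:
def evaluate():
    # Inside the "with" block above, the training Session is the default session.
    sess = tf.compat.v1.get_default_session()
    # In graph modification mode the framework-provided initializer would be run instead.
    sess.run(eval_iterator.initializer)
    while True:
        try:
            sess.run(eval_model.loss_list)
        except tf.errors.OutOfRangeError:
            break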
- Close the data stream and release resources. Call terminate_config_initializer in main.py to close the data stream and release resources. For the interface, see terminate_config_initializer.
terminate_config_initializer()
logging.info("Demo done!")