def create_iteration_per_loop_var(self, train_op)
该接口和load_iteration_per_loop_var接口配合使用,用来实现sess.run模式下设置小循环次数,即每次sess.run()在Device侧执行训练迭代的次数。该接口的主要作用是修改图,并通过load_iteration_per_loop_var接口来设置小循环次数。
参数名 |
输入/输出 |
描述 |
---|---|---|
train_op |
输入 |
更新变量或梯度的操作。 |
返回一个算子,供用户通过sess.run(op)调用。
# 训练模型 with tf.Session(config=config) as sess: sess.run(init) # sess.run模式下设置小循环次数为10 iteration = util.IterationPerLoop() train_op = iteration.create_iteration_per_loop_var(optimizer) #修改图 tf.train.Supervisor(logdir="/home/xxxx",init_op=init) #冻结图 iteration.load_iteration_per_loop_var(sess, 10) #设置小循环次数 for epoch in range(training_epochs): avg_cost = 0 total_batch = int(mnist.train.num_examples / batch_size) for i in range(total_batch): batch_xs, batch_ys = mnist.train.next_batch(batch_size) _, c = sess.run([train_op, cost], feed_dict={x: batch_xs, y: batch_ys}) avg_cost += c / total_batch