I want to set up a distributed TensorFlow model, but I fail to understand how MonitoredTrainingSession and StopAtStepHook interact. Before, I had this setup:
for epoch in range(training_epochs):
    for i in range(total_batch - 1):
        c, p, s = sess.run([cost, prediction, summary_op], feed_dict={x: batch_x, y: batch_y})
Now I have this setup (simplified):
def run_nn_model(learning_rate, log_param, optimizer, batch_size, layer_config):
    with tf.device(tf.train.replica_device_setter(
            worker_device="/job:worker/task:%d" % mytaskid,
            cluster=cluster)):
        # [variables...]
        hooks = [tf.train.StopAtStepHook(last_step=100)]

    if myjob == "ps":
        server.join()
    elif myjob == "worker":
        with tf.train.MonitoredTrainingSession(master=server.target,
                                               is_chief=(mytaskid == 0),
                                               checkpoint_dir='/tmp/train_logs',
                                               hooks=hooks) as sess:
            while not sess.should_stop():
                # for epoch in range... [see above]
Is this wrong? It throws:
RuntimeError: Run called even after should_stop requested.
Command exited with non-zero status 1
Can somebody explain to me how TensorFlow coordinates this? How can I use the step counter to keep track of the training? (Before, I had the handy epoch variable.)
Answer 0 (score: 2)
The step counter is incremented every time sess.run executes a training step. The problem here is that you run more steps ((total_batch - 1) * training_epochs) than the number of steps specified in the hook (last_step=100 in your snippet), so the nested epoch/batch loops keep calling run after the hook has already requested a stop, which raises the RuntimeError.

What you can do, even though I don't think it is clean syntax, is define last_step = (total_batch - 1) * training_epochs.
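A rough sketch of that workaround, which also shows how to recover your old epoch counter from the global step (names such as total_batch, next_batch, optimizer, cost, x and y are assumptions carried over from your snippet, not a verified implementation):

# Hypothetical sketch: stop after exactly training_epochs passes over the data.
steps_per_epoch = total_batch - 1
hooks = [tf.train.StopAtStepHook(last_step=training_epochs * steps_per_epoch)]

# StopAtStepHook watches the global step, so the training op must increment it.
global_step = tf.train.get_or_create_global_step()
train_op = optimizer.minimize(cost, global_step=global_step)

with tf.train.MonitoredTrainingSession(master=server.target,
                                       is_chief=(mytaskid == 0),
                                       checkpoint_dir='/tmp/train_logs',
                                       hooks=hooks) as sess:
    while not sess.should_stop():
        batch_x, batch_y = next_batch(batch_size)   # assumed batch helper
        _, step = sess.run([train_op, global_step],
                           feed_dict={x: batch_x, y: batch_y})
        epoch = step // steps_per_epoch             # replaces the old epoch variable

The key point is that should_stop() is checked before every single run call, so there is no inner epoch/batch nest that can keep calling run after the hook has fired.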