TensorFlow 集群模式下,如果某个 worker 较晚关闭会话,worker 进程会挂起

时间:2018-10-31 05:58:16

标签: tensorflow

以本问题中的代码为例,将其保存为 trainer.py,然后可以用以下命令启动两个 worker 和一个 ps:

python trainer.py \
     --ps_hosts=localhost:2222 \
     --worker_hosts=localhost:2223,localhost:2224 \
     --job_name=worker --task_index=0

python trainer.py \
     --ps_hosts=localhost:2222 \
     --worker_hosts=localhost:2223,localhost:2224 \
     --job_name=worker --task_index=1

python trainer.py \
     --ps_hosts=localhost:2222 \
     --worker_hosts=localhost:2223,localhost:2224 \
     --job_name=ps --task_index=0

训练结束后,worker 0 会挂起,无法退出。如果删除以下代码:

    if is_chief:
        import time
        time.sleep(2)

那么所有 worker 会同时退出。看起来所有 worker 都必须在同一时间关闭会话。这是预期行为,还是我的用法不正确?

import argparse
import sys

import tensorflow as tf
import numpy as np

FLAGS = None


def main(_):
  """Run one task (parameter server or worker) of the distributed example.

  Reads the cluster layout and this task's role from the module-level
  ``FLAGS`` (assigned in the ``__main__`` block before ``tf.app.run``).

  * ``ps``: starts the server and blocks in ``server.join()``.
  * ``worker``: fits ``y = w * x + b`` to 100 synthetic points, feeding one
    sample per training step, until the ``StopAtStepHook`` requests a stop.
  * any other job name: neither branch runs; the function returns after
    creating the server.

  Args:
    _: Unused; receives the argv list that ``tf.app.run`` passes to main.
  """
  # Comma-separated "host:port" strings -> lists of addresses.
  ps_hosts = FLAGS.ps_hosts.split(",")
  worker_hosts = FLAGS.worker_hosts.split(",")

  # Create a cluster from the parameter server and worker hosts.
  cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})

  # Create and start a server for the local task.
  server = tf.train.Server(cluster,
                           job_name=FLAGS.job_name,
                           task_index=FLAGS.task_index)

  if FLAGS.job_name == "ps":
    # Parameter servers only host variables; block here serving requests.
    server.join()
  elif FLAGS.job_name == "worker":

    # Synthetic data: y = 2x + 10 plus Gaussian noise scaled by 0.33.
    train_X = np.linspace(-1.0, 1.0, 100)
    train_Y = 2.0 * train_X + np.random.randn(*train_X.shape) * 0.33 + 10.0

    # Fed with one sample at a time by the training loop below.
    X = tf.placeholder("float")
    Y = tf.placeholder("float")
    # Assigns ops to the local worker by default; variables are placed on
    # the ps tasks by replica_device_setter.
    with tf.device(tf.train.replica_device_setter(
        worker_device="/job:worker/task:%d" % FLAGS.task_index,
        cluster=cluster)):

      # Build model: linear regression with a per-sample squared-error loss.
      w = tf.Variable(0.0, name="weight")
      b = tf.Variable(0.0, name="bias")
      loss =  tf.square(Y - tf.multiply(X, w) - b)
      # Step counter incremented by every worker's train_op; StopAtStepHook
      # below stops training based on it.
      global_step = tf.contrib.framework.get_or_create_global_step()

      train_op = tf.train.AdagradOptimizer(0.01).minimize(
          loss, global_step=global_step)

    # The StopAtStepHook handles stopping after running given steps.
    hooks=[tf.train.StopAtStepHook(last_step=10000)]

    # Worker task 0 is the chief passed to MonitoredTrainingSession.
    is_chief = (FLAGS.task_index == 0)
    # The MonitoredTrainingSession takes care of session initialization,
    # restoring from a checkpoint, saving to a checkpoint, and closing when done
    # or an error occurs.
    with tf.train.MonitoredTrainingSession(master=server.target,
                                           is_chief=is_chief,
                                           checkpoint_dir="/tmp/train_logs",
                                           hooks=hooks) as mon_sess:
      # Outer loop: repeated passes over the 100 samples.
      while not mon_sess.should_stop():
        # Run a training step asynchronously.
        # See tf.train.SyncReplicasOptimizer for additional details on how to
        # perform *synchronous* training.
        # mon_sess.run handles AbortedError in case of preempted PS.
        for (x, y) in zip(train_X, train_Y):
            # Re-check before every step so a stop requested mid-pass is
            # honored without finishing the pass.
            if mon_sess.should_stop():
                break
            _, step, loss_value = mon_sess.run([train_op, global_step, loss], feed_dict={X: x,
                                                                       Y: y})

            print("Step: {}, loss: {}".format(step, loss_value))
        # NOTE(review): the chief sleeps 2 s after every pass. This slows
        # the chief relative to worker 1 and may leave it sleeping when the
        # step limit is reached, so it closes its session later. The
        # question text ties this sleep to the exit hang of worker 0; the
        # root cause is not established here. `time` is imported lazily
        # because only the chief uses it.
        if is_chief:
            import time
            time.sleep(2)

if __name__ == "__main__":
  # Parse the cluster/task flags into the module-level FLAGS read by main().
  parser = argparse.ArgumentParser()
  # Flags for defining the tf.train.ClusterSpec
  parser.add_argument(
      "--ps_hosts",
      type=str,
      default="",
      help="Comma-separated list of hostname:port pairs"
  )
  parser.add_argument(
      "--worker_hosts",
      type=str,
      default="",
      help="Comma-separated list of hostname:port pairs"
  )
  # main() only has branches for "ps" and "worker", so reject any other
  # explicitly supplied value at parse time. The "" default is kept for
  # backward compatibility: argparse does not check defaults against choices.
  parser.add_argument(
      "--job_name",
      type=str,
      default="",
      choices=("ps", "worker"),
      help="One of 'ps', 'worker'"
  )
  # Flags for defining the tf.train.Server
  parser.add_argument(
      "--task_index",
      type=int,
      default=0,
      help="Index of task within the job"
  )
  # Unrecognized arguments are forwarded to tf.app.run instead of rejected.
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

0 个答案:

没有答案