Taking the code in this question as an example, save it as trainer.py. You can then try starting two workers and one ps with the following commands:
python trainer.py \
    --ps_hosts=localhost:2222 \
    --worker_hosts=localhost:2223,localhost:2224 \
    --job_name=worker --task_index=0

python trainer.py \
    --ps_hosts=localhost:2222 \
    --worker_hosts=localhost:2223,localhost:2224 \
    --job_name=worker --task_index=1

python trainer.py \
    --ps_hosts=localhost:2222 \
    --worker_hosts=localhost:2223,localhost:2224 \
    --job_name=ps --task_index=0
After training finishes, worker 0 hangs and never exits if I delete the following code:
if is_chief:
  import time
  time.sleep(2)
With it kept, all the workers exit at the same time. It seems that all workers have to close their sessions at the same time. Is this expected behavior, or am I using it incorrectly?
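For reference, here is a minimal sketch of one variation I am wondering about: restricting each worker's session to the ps job and its own task via device_filters in the session config, so that closing one worker's session does not need to reach the other workers. It reuses FLAGS, server, is_chief, and hooks from the script below, and I have not verified that it avoids the hang in this exact setup:

# Sketch only (unverified for this hang): limit which remote devices this
# worker's session can see, so it does not have to reach the other workers.
config = tf.ConfigProto(
    device_filters=["/job:ps",
                    "/job:worker/task:%d" % FLAGS.task_index])

with tf.train.MonitoredTrainingSession(master=server.target,
                                       is_chief=is_chief,
                                       checkpoint_dir="/tmp/train_logs",
                                       hooks=hooks,
                                       config=config) as mon_sess:
  while not mon_sess.should_stop():
    ...  # same training loop as in the full script below

The complete trainer.py is below: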
import argparse
import sys
import tensorflow as tf
import numpy as np
FLAGS = None
def main(_):
  ps_hosts = FLAGS.ps_hosts.split(",")
  worker_hosts = FLAGS.worker_hosts.split(",")

  # Create a cluster from the parameter server and worker hosts.
  cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})

  # Create and start a server for the local task.
  server = tf.train.Server(cluster,
                           job_name=FLAGS.job_name,
                           task_index=FLAGS.task_index)

  if FLAGS.job_name == "ps":
    server.join()
  elif FLAGS.job_name == "worker":
    train_X = np.linspace(-1.0, 1.0, 100)
    train_Y = 2.0 * train_X + np.random.randn(*train_X.shape) * 0.33 + 10.0

    X = tf.placeholder("float")
    Y = tf.placeholder("float")

    # Assigns ops to the local worker by default.
    with tf.device(tf.train.replica_device_setter(
        worker_device="/job:worker/task:%d" % FLAGS.task_index,
        cluster=cluster)):

      # Build model...
      w = tf.Variable(0.0, name="weight")
      b = tf.Variable(0.0, name="bias")
      loss = tf.square(Y - tf.multiply(X, w) - b)

      global_step = tf.contrib.framework.get_or_create_global_step()

      train_op = tf.train.AdagradOptimizer(0.01).minimize(
          loss, global_step=global_step)

    # The StopAtStepHook handles stopping after running given steps.
    hooks = [tf.train.StopAtStepHook(last_step=10000)]

    is_chief = (FLAGS.task_index == 0)

    # The MonitoredTrainingSession takes care of session initialization,
    # restoring from a checkpoint, saving to a checkpoint, and closing when done
    # or an error occurs.
    with tf.train.MonitoredTrainingSession(master=server.target,
                                           is_chief=is_chief,
                                           checkpoint_dir="/tmp/train_logs",
                                           hooks=hooks) as mon_sess:
      while not mon_sess.should_stop():
        # Run a training step asynchronously.
        # See tf.train.SyncReplicasOptimizer for additional details on how to
        # perform *synchronous* training.
        # mon_sess.run handles AbortedError in case of preempted PS.
        for (x, y) in zip(train_X, train_Y):
          if mon_sess.should_stop():
            break
          _, step, loss_value = mon_sess.run([train_op, global_step, loss],
                                             feed_dict={X: x, Y: y})
          print("Step: {}, loss: {}".format(step, loss_value))
    if is_chief:
      import time
      time.sleep(2)
if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.register("type", "bool", lambda v: v.lower() == "true")
  # Flags for defining the tf.train.ClusterSpec
  parser.add_argument(
      "--ps_hosts",
      type=str,
      default="",
      help="Comma-separated list of hostname:port pairs"
  )
  parser.add_argument(
      "--worker_hosts",
      type=str,
      default="",
      help="Comma-separated list of hostname:port pairs"
  )
  parser.add_argument(
      "--job_name",
      type=str,
      default="",
      help="One of 'ps', 'worker'"
  )
  # Flags for defining the tf.train.Server
  parser.add_argument(
      "--task_index",
      type=int,
      default=0,
      help="Index of task within the job"
  )
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)