我正在尝试把分布式 TensorFlow 的示例改写成使用 Keras 的版本。我不确定这是否完全可行,但还是想试一试。下面是我用来启动 worker 的脚本:
import tensorflow as tf
import keras
from keras import backend as K
from keras.layers import Dense, Dropout, InputLayer
from keras.models import Sequential
from keras.datasets import mnist
def get_data():
    """Load the MNIST training split and prepare it for a dense network.

    The images are reshaped to (60000, 784), cast to float32 and divided
    by 255; the labels are one-hot encoded into 10 classes.

    Returns:
        A tuple ``(x_train, y_train)`` of the prepared images and labels.
    """
    num_classes = 10
    # mnist.load_data() also returns a test split; this script never uses
    # it, so it is discarded instead of being preprocessed.
    (x_train, y_train), _ = mnist.load_data()
    x_train = x_train.reshape(60000, 784).astype('float32')
    x_train /= 255
    # convert class vectors to binary class matrices
    y_train = keras.utils.to_categorical(y_train, num_classes)
    return x_train, y_train
# Command-line flags describing the cluster and this worker's index in it.
tf.app.flags.DEFINE_string("ps_hosts", "", "List of parameter server addresses'")
tf.app.flags.DEFINE_string("worker_hosts", "", "List of worker addresses'")
# An integer flag needs an integer default; 0 matches the default used by
# the parameter-server launcher.
tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job")
FLAGS = tf.app.flags.FLAGS

task_index = FLAGS.task_index
# Both host flags are comma-separated address lists.
ps_hosts = FLAGS.ps_hosts.split(",")
worker_hosts = FLAGS.worker_hosts.split(",")

cluster_spec = tf.train.ClusterSpec({'ps': ps_hosts, 'worker': worker_hosts})
server = tf.train.Server(cluster_spec, job_name='worker', task_index=task_index)

worker_device = 'job:worker/task:{}'.format(task_index)
replica_device = tf.train.replica_device_setter(
    worker_device=worker_device,
    cluster=cluster_spec)

x_train, y_train = get_data()

# NOTE(review): the original indentation of this script was lost; every
# statement below is assumed to belong to this device scope -- confirm.
with tf.device(replica_device):
    K.manual_variable_initialization(True)
    global_step = tf.get_variable(name='global_step', shape=[],
                                  initializer=tf.constant_initializer(0),
                                  trainable=False)
    init_op = tf.global_variables_initializer()
    # NOTE(review): tf.train.Supervisor appears to finalize the graph when
    # it is constructed, so the Keras layers added further down fail with
    # "Graph is finalized and cannot be modified." Building and compiling
    # the model before creating the Supervisor is the likely fix -- confirm
    # against the Supervisor documentation. The statement order is kept
    # as-is here because it reproduces the reported failure.
    supervisor = tf.train.Supervisor(is_chief=(task_index == 0),
                                     global_step=global_step,
                                     init_op=init_op)
    sess = supervisor.prepare_or_wait_for_session(server.target)
    K.set_session(sess)

    # 784 -> 512 -> 512 -> 10 fully connected classifier.
    model = Sequential()
    model.add(Dense(512, activation='relu', input_shape=(784,)))
    model.add(Dropout(0.2))
    model.add(Dense(512, activation='relu'))
    model.add(Dense(10, activation='softmax'))
    model.compile(optimizer='sgd', loss='categorical_crossentropy')

    # Train on the first 1000 samples only, for a single epoch.
    model.fit(x_train[:1000], y_train[:1000], epochs=1)
这是我启动参数服务器的脚本
import tensorflow as tf
if __name__ == '__main__':
    # The cluster layout and this server's index come from the command line.
    tf.app.flags.DEFINE_string("ps_hosts", "", "List of parameter server addresses'")
    tf.app.flags.DEFINE_string("worker_hosts", "", "List of worker addresses'")
    tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job")
    FLAGS = tf.app.flags.FLAGS

    task_index = FLAGS.task_index
    # Both host flags are comma-separated address lists.
    ps_hosts = FLAGS.ps_hosts.split(",")
    worker_hosts = FLAGS.worker_hosts.split(",")

    # Map each job name to the addresses of its tasks.
    cluster = {"ps": ps_hosts, "worker": worker_hosts}

    # This process always runs as a task of the 'ps' job.
    server = tf.train.Server(cluster, job_name='ps', task_index=task_index)
    # Keep the process running so the server stays available to the workers.
    server.join()
这是我用来启动这些进程的脚本:
#!/bin/bash
# Launch a local cluster: one parameter server and two workers.
# The first two processes run in the background; the last worker runs in
# the foreground.

PS_HOSTS="localhost:2222"
WORKER_HOSTS="localhost:2223,localhost:2224"

# start_ps.py defines no --job_name flag (it always starts the 'ps' job),
# so that flag is not passed.
python3 start_ps.py --ps_hosts="$PS_HOSTS" --worker_hosts="$WORKER_HOSTS" --task_index=0 &
python3 start_worker.py --ps_hosts="$PS_HOSTS" --worker_hosts="$WORKER_HOSTS" --task_index=0 &
python3 start_worker.py --ps_hosts="$PS_HOSTS" --worker_hosts="$WORKER_HOSTS" --task_index=1
运行这个脚本时,我得到了如下错误:
Traceback (most recent call last):
File "start_worker.py", line 62, in <module>
model.add(Dense(512, activation='relu', input_shape=(784,)))
File "/Users/user/projects/keras/keras/models.py", line 421, in add
dtype=layer.dtype, name=layer.name + '_input')
File "/Users/user/projects/keras/keras/engine/topology.py", line 1375, in Input
input_tensor=tensor)
File "/Users/pzx496/projects/keras/keras/engine/topology.py", line 1286, in __init__
name=self.name)
File "/Users/pzx496/projects/keras/keras/backend/tensorflow_backend.py", line 349, in placeholder
x = tf.placeholder(dtype, shape=shape, name=name)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/array_ops.py", line 1530, in placeholder
return gen_array_ops._placeholder(dtype=dtype, shape=shape, name=name)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/gen_array_ops.py", line 1954, in _placeholder
name=name)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 767, in apply_op
op_def=op_def)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 2458, in create_op
self._check_not_finalized()
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 2181, in _check_not_finalized
raise RuntimeError("Graph is finalized and cannot be modified.")
RuntimeError: Graph is finalized and cannot be modified.
我原以为设置 K.manual_variable_initialization(True) 之后,Keras 在我明确指示之前不会初始化变量,但情况似乎并非如此。有人知道该如何解决这个问题吗?