我在单个会话环境中一直使用Tensorflow中的cudnn_rnn模型,它们运行正常。但是,当我尝试在具有1个PS主机和几个GPU工作线程的分布式运行中使用cudnnLSTM时,Tensorflow崩溃。
from tensorflow.contrib.cudnn_rnn.python.layers import cudnn_rnn
with tf.device(tf.train.replica_device_setter(
worker_device = "/job:worker/task:%d" % TASK_INDEX, cluster = cluster)):
lstm = cudnn_rnn.CudnnLSTM(self.layers, self.hidden_units)
with tf.train.MonitoredTrainingSession(master = server.target,
is_chief = (TASK_INDEX == 0),
checkpoint_dir = CHECKPOINT_DIR,
hooks = hooks) as sess:
...
我的一个工作进程(可以访问GPU)中出现以下错误:
InvalidArgumentError (see above for traceback): Cannot assign a device for operation 'save/CudnnRNNCanonicalToParams': Could not satisfy explicit device specification '/job:worker/task:0/device:CPU:0' because no supported kernel for CPU devices is available.
[[Node: save/CudnnRNNCanonicalToParams = CudnnRNNCanonicalToParams[T=DT_FLOAT, direction="unidirectional", dropout=0, input_mode="linear_input", num_params=12, rnn_mode="gru", seed=0, seed2=0, _device="/job:worker/task:0/device:CPU:0"](save/CudnnRNNCanonicalToParams/num_layers, save/CudnnRNNCanonicalToParams/num_units, save/CudnnRNNCanonicalToParams/input_size, save/Reshape, save/Reshape_1, save/Reshape_2, save/Reshape_3, save/Reshape_4, save/Reshape_5, save/Reshape_6, save/Reshape_7, save/Reshape_8, save/Reshape_9, save/Reshape_10, save/Reshape_11, save/split_3, save/split_3:1, save/RestoreV2_22, save/split_4, save/split_4:1, save/RestoreV2_23, save/split_8, save/split_8:1, save/RestoreV2_25, save/split_9, save/split_9:1, save/RestoreV2_26)]]
我尝试在save_checkpoint_secs = None
中设置MonitoredTrainingSession
,但仍会遇到同样的错误。
我已阅读tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py
中提及保存参数和使用PS服务器的评论,但无法找到有效的示例。关于如何使分布式张量流和cudnnLSTM一起工作的任何想法?
更新 @ Ash关于更新张量流的答案有所帮助。此外,目前,我需要在Saver中指定无分片:
with tf.train.MonitoredTrainingSession(master = server.target,
is_chief = (TASK_INDEX == 0),
checkpoint_dir = CHECKPOINT_DIR,
scaffold = tf.train.Scaffold(
saver = tf.train.Saver(sharded = False, allow_empty = True)),
hooks = hooks) as sess:
答案 0 :(得分:1)
我相信这是一个已在HEAD中修复的错误,但该修复程序尚未发布任何版本,因此要获得修复,您必须从源代码构建TensorFlow,或以某种方式在安装中包含相同的修复程序
此修复位于此提交中: https://github.com/tensorflow/tensorflow/commit/56da08fed6862422904411a61059b38940a57338