Distributed TensorFlow with CudnnLSTM

Time: 2018-01-10 22:46:17

Tags: tensorflow

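I have been using the cudnn_rnn models in TensorFlow in a single-session setup, and they work fine. However, when I try to use CudnnLSTM in a distributed run with 1 PS host and several GPU workers, TensorFlow crashes.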

import tensorflow as tf
from tensorflow.contrib.cudnn_rnn.python.layers import cudnn_rnn

# Place variables on the PS job and the remaining ops on this worker.
with tf.device(tf.train.replica_device_setter(
        worker_device="/job:worker/task:%d" % TASK_INDEX, cluster=cluster)):
    lstm = cudnn_rnn.CudnnLSTM(self.layers, self.hidden_units)

with tf.train.MonitoredTrainingSession(master=server.target,
                                       is_chief=(TASK_INDEX == 0),
                                       checkpoint_dir=CHECKPOINT_DIR,
                                       hooks=hooks) as sess:
    ...
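
The snippet above relies on `cluster`, `server`, and `TASK_INDEX`, which the question does not show. As a minimal sketch of how they might be defined for "1 PS host and several GPU workers" with the TF 1.x `tf.train.ClusterSpec` and `tf.train.Server` APIs: the host addresses and the helper names `build_cluster_dict`, `worker_device`, and `start_server` are assumptions made for illustration, not part of the original post.

```python
# Sketch only: host addresses and helper names are hypothetical.

def build_cluster_dict(ps_hosts, worker_hosts):
    """Map job names to task address lists, the form tf.train.ClusterSpec accepts."""
    return {"ps": list(ps_hosts), "worker": list(worker_hosts)}


def worker_device(task_index):
    """Device string for one worker task, in the format used in the question."""
    return "/job:worker/task:%d" % task_index


def start_server(cluster_dict, job_name, task_index):
    """Create the ClusterSpec and the in-process Server for this task (TF 1.x)."""
    import tensorflow as tf  # imported lazily so the helpers above need no TF
    cluster = tf.train.ClusterSpec(cluster_dict)
    server = tf.train.Server(cluster, job_name=job_name, task_index=task_index)
    return cluster, server


CLUSTER_DICT = build_cluster_dict(
    ps_hosts=["localhost:2222"],                         # assumed PS address
    worker_hosts=["localhost:2223", "localhost:2224"])   # assumed worker addresses
```

On a worker, `cluster, server = start_server(CLUSTER_DICT, "worker", TASK_INDEX)` would supply the two names the question's code expects; a PS task would typically start its server the same way with `job_name="ps"` and then call `server.join()`.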

I get the following error in one of my worker processes (which has access to a GPU):

InvalidArgumentError (see above for traceback): Cannot assign a device for operation 'save/CudnnRNNCanonicalToParams': Could not satisfy explicit device specification '/job:worker/task:0/device:CPU:0' because no supported kernel for CPU devices is available.
 [[Node: save/CudnnRNNCanonicalToParams = CudnnRNNCanonicalToParams[T=DT_FLOAT, direction="unidirectional", dropout=0, input_mode="linear_input", num_params=12, rnn_mode="gru", seed=0, seed2=0, _device="/job:worker/task:0/device:CPU:0"](save/CudnnRNNCanonicalToParams/num_layers, save/CudnnRNNCanonicalToParams/num_units, save/CudnnRNNCanonicalToParams/input_size, save/Reshape, save/Reshape_1, save/Reshape_2, save/Reshape_3, save/Reshape_4, save/Reshape_5, save/Reshape_6, save/Reshape_7, save/Reshape_8, save/Reshape_9, save/Reshape_10, save/Reshape_11, save/split_3, save/split_3:1, save/RestoreV2_22, save/split_4, save/split_4:1, save/RestoreV2_23, save/split_8, save/split_8:1, save/RestoreV2_25, save/split_9, save/split_9:1, save/RestoreV2_26)]]

I tried setting save_checkpoint_secs = None in MonitoredTrainingSession, but I still get the same error.

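I have read the comments in tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py about saving parameters and using PS servers, but I could not find a working example. Any ideas on how to make distributed TensorFlow and CudnnLSTM work together?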

Update: @Ash's answer about updating TensorFlow helped. In addition, for now, I need to disable sharding in the Saver:

with tf.train.MonitoredTrainingSession(
        master=server.target,
        is_chief=(TASK_INDEX == 0),
        checkpoint_dir=CHECKPOINT_DIR,
        # Write a single, unsharded checkpoint; allow_empty makes the
        # Saver a no-op instead of raising if there are no variables.
        scaffold=tf.train.Scaffold(
            saver=tf.train.Saver(sharded=False, allow_empty=True)),
        hooks=hooks) as sess:
    ...

1 Answer:

Answer 0: (score: 1)

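I believe this is a bug that has already been fixed at HEAD, but the fix has not yet made it into any release. To get it, you have to build TensorFlow from source, or otherwise include the same fix in your installation.

The fix is in this commit: https://github.com/tensorflow/tensorflow/commit/56da08fed6862422904411a61059b38940a57338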
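
Before building from source, it may be worth confirming that the local checkout actually includes that commit. The sketch below is one possible way to do so, assuming `git` is installed and a clone of the TensorFlow repository is available; the helper name `checkout_contains` and the `repo_dir` argument are hypothetical. It relies on `git merge-base --is-ancestor`, which exits with status 0 when the first commit is an ancestor of the second.

```python
import subprocess

# Commit hash taken from the answer above.
FIX_COMMIT = "56da08fed6862422904411a61059b38940a57338"


def checkout_contains(commit, repo_dir="."):
    """Return True only if `commit` is an ancestor of HEAD in the checkout at repo_dir.

    Returns False when the commit is not an ancestor, when it is unknown to the
    repository, when repo_dir is not a git checkout, or when git cannot be run.
    """
    try:
        result = subprocess.run(
            ["git", "merge-base", "--is-ancestor", commit, "HEAD"],
            cwd=repo_dir,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
    except OSError:
        # git is not installed, or repo_dir does not exist.
        return False
    return result.returncode == 0
```

For example, `checkout_contains(FIX_COMMIT, "/path/to/tensorflow")` (the path is a placeholder) would report whether a build from that checkout should include the fix.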