How do I use the feedable iterator from the TensorFlow Dataset API together with MonitoredTrainingSession?

Date: 2017-09-08 07:34:51

Tags: tensorflow tensorflow-datasets

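A ticket suggests using a feedable iterator to switch between the training and validation datasets without reinitializing the iterator. The main requirement is to feed a handle that selects between them.

How can this be used together with tf.train.MonitoredTrainingSession?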

My attempt fails with the error "RuntimeError: Graph is finalized and cannot be modified."

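How can I get the convenience of MonitoredTrainingSession and still iterate over the training and validation datasets at the same time?

3 answers:

Answer 0 (score: 5)

I got the answer from a TensorFlow GitHub issue: https://github.com/tensorflow/tensorflow/issues/12859

The solution is to call iterator.string_handle() before creating the MonitoredSession, so that the handle ops are already in the graph when the session finalizes it.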

import tensorflow as tf
from tensorflow.contrib.data import Dataset, Iterator

dataset_train = Dataset.range(10)
dataset_val = Dataset.range(90, 100)

iter_train_handle = dataset_train.make_one_shot_iterator().string_handle()
iter_val_handle = dataset_val.make_one_shot_iterator().string_handle()

handle = tf.placeholder(tf.string, shape=[])
iterator = Iterator.from_string_handle(
    handle, dataset_train.output_types, dataset_train.output_shapes)
next_batch = iterator.get_next()

with tf.train.MonitoredTrainingSession() as sess:
    handle_train, handle_val = sess.run([iter_train_handle, iter_val_handle])

    for step in range(10):
        print('train', sess.run(next_batch, feed_dict={handle: handle_train}))

        if step % 3 == 0:
            print('val', sess.run(next_batch, feed_dict={handle: handle_val}))

Output (first lines):
('train', 0)
('val', 90)
('train', 1)
('train', 2)
('train', 3)
('val', 91)

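Answer 1 (score: 1)

@Michael Jaison G's answer is correct. However, it does not work when you also want to use a session_run_hook that needs to evaluate part of the graph, e.g. LoggingTensorHook or SummarySaverHook. The following example raises an error: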

import tensorflow as tf

dataset_train = tf.data.Dataset.range(10)
dataset_val = tf.data.Dataset.range(90, 100)

iter_train_handle = dataset_train.make_one_shot_iterator().string_handle()
iter_val_handle = dataset_val.make_one_shot_iterator().string_handle()

handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle, dataset_train.output_types, dataset_train.output_shapes)
feature = iterator.get_next()

pred = feature * feature
tf.summary.scalar('pred', pred)
global_step = tf.train.create_global_step()

summary_hook = tf.train.SummarySaverHook(save_steps=5,
                                         output_dir="summaries", summary_op=tf.summary.merge_all())

with tf.train.MonitoredTrainingSession(hooks=[summary_hook]) as sess: 
    handle_train, handle_val = sess.run([iter_train_handle, iter_val_handle])

    for step in range(10):
        feat = sess.run(feature, feed_dict={handle: handle_train})
        pred_ = sess.run(pred, feed_dict={handle: handle_train})
        print('train: ', feat)
        print('pred: ', pred_)

        if step % 3 == 0:
            print('val', sess.run(feature, feed_dict={handle: handle_val}))

This fails with the error:

InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'Placeholder' with dtype string
     [[Node: Placeholder = Placeholder[dtype=DT_STRING, shape=[], _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]
     [[Node: cond/Switch_1/_15 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device_incarnation=1, tensor_name="edge_18_cond/Switch_1", tensor_type=DT_INT64, _device="/job:localhost/replica:0/task:0/device:GPU:0"]()]]

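The reason is that the hook tries to evaluate the graph on the very first session.run([iter_train_handle, iter_val_handle]) call, whose feed_dict obviously does not contain the handle yet.

One workaround is to override the hooks that cause the problem and change the code in before_run and after_run so that they only evaluate on session.run calls whose feed_dict contains the handle. The feed_dict of the current session.run call is available through the run_context argument of before_run and after_run.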
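The workaround above can be sketched as follows. This is a minimal, hedged example rather than a drop-in replacement for SummarySaverHook: `FetchWhenFedHook` is a hypothetical class name, and it simply records a tensor's value instead of writing summaries. It assumes the `tf.compat.v1` API so that it should also run on TensorFlow 2.x.

```python
import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()
tf1.reset_default_graph()  # start from a fresh, non-finalized graph


class FetchWhenFedHook(tf1.train.SessionRunHook):
    """Hypothetical hook: fetches `tensor` only if `placeholder` is being fed."""

    def __init__(self, placeholder, tensor):
        self._placeholder = placeholder
        self._tensor = tensor
        self.values = []

    def before_run(self, run_context):
        feed_dict = run_context.original_args.feed_dict or {}
        # Compare by identity so the check does not rely on tensor equality.
        if any(key is self._placeholder for key in feed_dict):
            return tf1.train.SessionRunArgs(self._tensor)
        return None  # add nothing to run() calls that do not feed the handle

    def after_run(self, run_context, run_values):
        # results is None whenever before_run requested nothing.
        if run_values.results is not None:
            self.values.append(int(run_values.results))


dataset_train = tf.data.Dataset.range(10)
iter_train_handle = tf1.data.make_one_shot_iterator(dataset_train).string_handle()

handle = tf1.placeholder(tf.string, shape=[])
iterator = tf1.data.Iterator.from_string_handle(
    handle,
    tf1.data.get_output_types(dataset_train),
    tf1.data.get_output_shapes(dataset_train))
feature = iterator.get_next()
pred = feature * feature

hook = FetchWhenFedHook(handle, pred)
with tf1.train.MonitoredTrainingSession(hooks=[hook]) as sess:
    # No feed_dict here, so the hook stays out of this call.
    handle_train = sess.run(iter_train_handle)
    features = [int(sess.run(feature, feed_dict={handle: handle_train}))
                for _ in range(3)]

print(features, hook.values)
```

Adapting a built-in hook the same way means subclassing it and skipping its before_run/after_run logic when the handle is absent; the details depend on that hook's internals.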

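Alternatively, you can use a recent version of TensorFlow (after 1.4), which adds a run_step_fn function to MonitoredSession. It lets you specify a step_fn like the following, which avoids the error (at the cost of evaluating the if statement on every training iteration):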

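handles = {}  # filled in on the first step

def step_fn(step_context):
  if not handles:
    # step_context.session is the raw session, so no hooks run for this call.
    handles['train'], handles['val'] = step_context.session.run(
        [iter_train_handle, iter_val_handle])
  return step_context.run_with_hooks(fetches=..., feed_dict=...)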
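For a complete picture, here is a hedged, self-contained sketch of the run_step_fn approach. It assumes `tf.compat.v1` (so it should also run on TensorFlow 2.x) and uses a LoggingTensorHook as a stand-in for any hook that evaluates the graph on every hooked run; the `handles` dict and the concrete fetches are illustrative choices, not part of the original answer.

```python
import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()
tf1.reset_default_graph()  # start from a fresh, non-finalized graph

dataset_train = tf.data.Dataset.range(10)
dataset_val = tf.data.Dataset.range(90, 100)
iter_train_handle = tf1.data.make_one_shot_iterator(dataset_train).string_handle()
iter_val_handle = tf1.data.make_one_shot_iterator(dataset_val).string_handle()

handle = tf1.placeholder(tf.string, shape=[])
iterator = tf1.data.Iterator.from_string_handle(
    handle,
    tf1.data.get_output_types(dataset_train),
    tf1.data.get_output_shapes(dataset_train))
feature = iterator.get_next()
pred = feature * feature

# This hook fetches `pred` on every hooked run, so each hooked run must feed `handle`.
logging_hook = tf1.train.LoggingTensorHook({'pred': pred}, every_n_iter=1)

handles = {}  # filled in on the first step


def step_fn(step_context):
    if not handles:
        # The raw session bypasses the hooks, so no placeholder needs feeding.
        handles['train'], handles['val'] = step_context.session.run(
            [iter_train_handle, iter_val_handle])
    return step_context.run_with_hooks(
        fetches=pred, feed_dict={handle: handles['train']})


with tf1.train.MonitoredTrainingSession(hooks=[logging_hook]) as sess:
    train_preds = [int(sess.run_step_fn(step_fn)) for _ in range(3)]
    # Ordinary hooked runs are fine once the handles are known.
    val_feature = int(sess.run(feature, feed_dict={handle: handles['val']}))

print(train_preds, val_feature)
```

Note that run_step_fn expects the function's single argument to be named `step_context`.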

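Answer 2 (score: 1)

There is a demo that uses a SessionRunHook to work with placeholders in a monitored session. The demo switches between datasets by feeding different handle strings.

By the way, I tried all of the solutions, and only this one worked for me.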

dataset_switching
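The demo itself is not reproduced here. As a rough sketch of how a SessionRunHook could take over the handle lookup, the example below assumes a hypothetical `HandleHook` that evaluates the handle ops in `after_create_session`, where it receives the raw session and other hooks are not involved. It is an illustration under those assumptions (written against `tf.compat.v1`), not the demo's actual code.

```python
import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()
tf1.reset_default_graph()  # start from a fresh, non-finalized graph


class HandleHook(tf1.train.SessionRunHook):
    """Hypothetical hook that evaluates the iterator handles once per session."""

    def __init__(self, handle_ops):
        self._handle_ops = handle_ops  # dict: name -> string_handle() tensor
        self.handles = {}

    def after_create_session(self, session, coord):
        # `session` is the raw session, so no other hook is run for this call.
        self.handles = session.run(self._handle_ops)


dataset_train = tf.data.Dataset.range(10)
dataset_val = tf.data.Dataset.range(90, 100)

# The handle ops are created before the session finalizes the graph.
handle_hook = HandleHook({
    'train': tf1.data.make_one_shot_iterator(dataset_train).string_handle(),
    'val': tf1.data.make_one_shot_iterator(dataset_val).string_handle(),
})

handle = tf1.placeholder(tf.string, shape=[])
iterator = tf1.data.Iterator.from_string_handle(
    handle,
    tf1.data.get_output_types(dataset_train),
    tf1.data.get_output_shapes(dataset_train))
feature = iterator.get_next()
pred = feature * feature

# Stand-in for any hook that needs the placeholder on every hooked run.
logging_hook = tf1.train.LoggingTensorHook({'pred': pred}, every_n_iter=1)

with tf1.train.MonitoredTrainingSession(hooks=[handle_hook, logging_hook]) as sess:
    train_values = [
        int(sess.run(feature, feed_dict={handle: handle_hook.handles['train']}))
        for _ in range(3)]
    val_value = int(
        sess.run(feature, feed_dict={handle: handle_hook.handles['val']}))

print(train_values, val_value)
```

With this arrangement every hooked `sess.run` call feeds the handle, so hooks that evaluate the graph never see an unfed placeholder.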