在MonitoredTrainingSession

时间:2017-12-29 19:15:23

标签: python tensorflow

我想尝试MonitoredTrainingSession,但我也使用了几个数据集对象来训练和验证集。并选择正确的一个,因为manual建议我使用字符串句柄。但是为了在训练时将句柄传递给feed_dict,我需要先评估它。像这样:

handle = sess.run(iterator.string_handle())

但是当我在MonitoredTrainingSession的上下文中执行此操作时,我收到错误:

RuntimeError: Graph is finalized and cannot be modified.

正如我想的那样,出路是为init_fn对象创建一个Scaffold,我将其传递给会话。但这并没有成功。如果我尝试在上下文中运行前面提到的代码init_fn仍然会得到相同的错误。

正如文档中提到的init_fn

  

在init op执行其他初始化之后运行的callable。

这让我觉得我对这个回调的预期目的完全错了,或者Tensorflow行为不端。

你能帮我解决这个困惑。

我的张量流版本为1.4.0

更新

添加一个最小的例子。第一个块有效,第二个块没有。

import tensorflow as tf

dataset_a = tf.data.Dataset.range(10)
dataset_b = tf.data.Dataset.range(20, 25)

input_handle = tf.placeholder(tf.string, shape=())
input_iterator = tf.data.Iterator.from_string_handle(
    input_handle, dataset_a.output_types, dataset_a.output_shapes)

x = input_iterator.get_next()
plus_one = tf.add(x, 1)

with tf.Session() as sess:
    iterator = dataset_b.make_initializable_iterator()
    handle = sess.run(iterator.string_handle())
    sess.run(iterator.initializer)

    res = sess.run(plus_one, feed_dict={input_handle: handle})
    print(res)

with tf.train.MonitoredTrainingSession() as sess:
    iterator = dataset_a.make_initializable_iterator()
    handle = sess.run(iterator.string_handle())
    sess.run(iterator.initializer)

    res = sess.run(plus_one, feed_dict={input_handle: handle})
    print(res)

1 个答案:

答案 0 :(得分:0)

我找到了问题的答案。我的想法是我不得不放弃使用句柄而是创建几个迭代器初始化器(每个数据集)。

解决方案如下:

import tensorflow as tf

dataset_a = tf.data.Dataset.range(10)
dataset_b = tf.data.Dataset.range(20, 25)

input_handle = tf.placeholder(tf.string, shape=(), name='input')
input_iterator = tf.data.Iterator.from_string_handle(
    input_handle, dataset_a.output_types, dataset_a.output_shapes)

x = input_iterator.get_next()
plus_one = tf.add(x, 1)

with tf.Session() as sess:
    iterator = dataset_b.make_initializable_iterator()
    handle = sess.run(iterator.string_handle())
    sess.run(iterator.initializer)

    res = sess.run(plus_one, feed_dict={input_handle: handle})
    print(res)


iterator = dataset_a.make_initializable_iterator()

iterator_init_op_a = iterator.make_initializer(dataset_a)
iterator_init_op_b = iterator.make_initializer(dataset_b)

x = iterator.get_next()
plus_one = tf.add(x, 1)

with tf.train.MonitoredTrainingSession() as sess:
    sess.run(iterator_init_op_a)
    res = sess.run(plus_one)
    print(res)
    res = sess.run(plus_one)
    print(res)
    sess.run(iterator_init_op_b)
    res = sess.run(plus_one)
    print(res)
    res = sess.run(plus_one)
    print(res)

如果数据集依赖于其他数据(例如,在this部分中),我可以在评估时将所需数据另外提供给iterator_init_op_a