我想尝试MonitoredTrainingSession
,但我也使用了几个数据集对象来训练和验证集。并选择正确的一个,因为manual建议我使用字符串句柄。但是为了在训练时将句柄传递给feed_dict
,我需要先评估它。像这样:
handle = sess.run(iterator.string_handle())
但是当我在MonitoredTrainingSession
的上下文中执行此操作时,我收到错误:
RuntimeError: Graph is finalized and cannot be modified.
正如我想的那样,出路是为init_fn
对象创建一个Scaffold
,我将其传递给会话。但这并没有成功。如果我尝试在上下文中运行前面提到的代码init_fn
仍然会得到相同的错误。
正如文档中提到的init_fn
:
在init op执行其他初始化之后运行的callable。
这让我觉得我对这个回调的预期目的完全错了,或者Tensorflow行为不端。
你能帮我解决这个困惑。
我的张量流版本为1.4.0
。
更新
添加一个最小的例子。第一个块有效,第二个块没有。
import tensorflow as tf
dataset_a = tf.data.Dataset.range(10)
dataset_b = tf.data.Dataset.range(20, 25)
input_handle = tf.placeholder(tf.string, shape=())
input_iterator = tf.data.Iterator.from_string_handle(
input_handle, dataset_a.output_types, dataset_a.output_shapes)
x = input_iterator.get_next()
plus_one = tf.add(x, 1)
with tf.Session() as sess:
iterator = dataset_b.make_initializable_iterator()
handle = sess.run(iterator.string_handle())
sess.run(iterator.initializer)
res = sess.run(plus_one, feed_dict={input_handle: handle})
print(res)
with tf.train.MonitoredTrainingSession() as sess:
iterator = dataset_a.make_initializable_iterator()
handle = sess.run(iterator.string_handle())
sess.run(iterator.initializer)
res = sess.run(plus_one, feed_dict={input_handle: handle})
print(res)
答案 0 :(得分:0)
我找到了问题的答案。我的想法是我不得不放弃使用句柄而是创建几个迭代器初始化器(每个数据集)。
解决方案如下:
import tensorflow as tf
dataset_a = tf.data.Dataset.range(10)
dataset_b = tf.data.Dataset.range(20, 25)
input_handle = tf.placeholder(tf.string, shape=(), name='input')
input_iterator = tf.data.Iterator.from_string_handle(
input_handle, dataset_a.output_types, dataset_a.output_shapes)
x = input_iterator.get_next()
plus_one = tf.add(x, 1)
with tf.Session() as sess:
iterator = dataset_b.make_initializable_iterator()
handle = sess.run(iterator.string_handle())
sess.run(iterator.initializer)
res = sess.run(plus_one, feed_dict={input_handle: handle})
print(res)
iterator = dataset_a.make_initializable_iterator()
iterator_init_op_a = iterator.make_initializer(dataset_a)
iterator_init_op_b = iterator.make_initializer(dataset_b)
x = iterator.get_next()
plus_one = tf.add(x, 1)
with tf.train.MonitoredTrainingSession() as sess:
sess.run(iterator_init_op_a)
res = sess.run(plus_one)
print(res)
res = sess.run(plus_one)
print(res)
sess.run(iterator_init_op_b)
res = sess.run(plus_one)
print(res)
res = sess.run(plus_one)
print(res)
如果数据集依赖于其他数据(例如,在this部分中),我可以在评估时将所需数据另外提供给iterator_init_op_a
。