在tensorflow中使用带有MonitoredTrainingSession的数据集时无法进行馈送

时间:2017-10-05 10:05:08

标签: python tensorflow

我的tensorflow 1.3.0以下代码。您可以毫无错误地运行它。 但是,如果取消注释tf.summary.scalar('test', next_batch)checkpoint_dir='/temp'

我得到了

InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'Placeholder' with dtype string [[Node: Placeholder = Placeholder[dtype=DT_STRING, shape=[], _device="/job:localhost/replica:0/task:0/cpu:0"]()]]

出了什么问题?

import tensorflow as tf

with tf.Graph().as_default():
    global_step = tf.contrib.framework.get_or_create_global_step()
    dataset_train = tf.contrib.data.Dataset.range(10)
    dataset_val = tf.contrib.data.Dataset.range(90, 100)

    iter_train_handle = dataset_train.make_one_shot_iterator().string_handle()
    iter_val_handle = dataset_val.make_one_shot_iterator().string_handle()

    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.contrib.data.Iterator.from_string_handle(
        handle, dataset_train.output_types, dataset_train.output_shapes)
    next_batch = iterator.get_next()
    # tf.summary.scalar('test', next_batch)

    with tf.train.MonitoredTrainingSession(
            # checkpoint_dir='/temp',
    ) as sess:
        handle_train, handle_val = sess.run([iter_train_handle, iter_val_handle])

        for step in range(10):
            print('train', sess.run(next_batch, feed_dict={handle: handle_train}))

            if step % 3 == 0:
                print('val', sess.run(next_batch, feed_dict={handle: handle_val}))

2 个答案:

答案 0 :(得分:1)

如果您仍要使用MonitoredTrainingSession,则解决此问题的方法是使用tf.placeholder_with_default。使用合理的初始默认值,然后根据需要对其进行更新。

答案 1 :(得分:0)

从github获得答案。来自@asimshankar。

  

也就是说,MonitoredTrainingSession运行摘要操作   由于您的摘要操作已连接到占位符,因此它是   抱怨缺少饲料。

     

如果您不想要MonitoredTrainingSession的功能,   我建议您直接使用MonitoredSessionSession

     

关闭它,因为它不是错误或功能请求   (MonitoredTrainingSession故意尝试运行摘要   定期进行操作。)