为什么字符串类型的tf.placeholder不适用于tf.string_input_producer()

时间:2017-10-11 00:50:16

标签: python tensorflow

我想使用占位符把输入文件名动态地送入文件名队列,从而逐个遍历文件。但是我发现下面的代码不起作用,有人知道原因吗?

import tensorflow as tf

def test(s):
    """Build a CSV-reading pipeline that reads the file named by `s` line by line.

    Args:
        s: string tensor holding the filename to read. In this script it is
            the placeholder defined below, which is what leads to the
            OutOfRangeError in the traceback.

    Returns:
        (col1, col2): tensors for the first column (parsed as float, default
        1.0) and the second column (parsed as int, default 1) of one CSV line.
    """
    # string_input_producer wraps [s] in a FIFOQueue and registers a
    # QueueRunner whose enqueue op is executed by background threads started
    # by tf.train.start_queue_runners. Those threads call sess.run without any
    # feed_dict, so a placeholder `s` is never fed there: the enqueue fails
    # and the runner closes the queue.
    filename_queue = tf.train.string_input_producer([s])

    reader = tf.TextLineReader()
    # Dequeues a filename from filename_queue when needed and returns the next
    # line (`key` is unused). This is the ReaderReadV2 op that raises
    # OutOfRangeError once the queue has been closed while empty.
    key, value = reader.read(filename_queue)

    # One default per column; the default's Python type sets the parsed dtype
    # (float for column 1, int for column 2).
    record_defaults = [[1.0], [1]]
    col1, col2 = tf.decode_csv(value, record_defaults = record_defaults)

    return col1, col2

# Filename meant to be fed at run time. Because test() hands `s` to
# string_input_producer, it is evaluated by a background QueueRunner thread,
# which never sees the feed_dict passed in the loop below.
s = tf.placeholder(tf.string, None, name = 's')
# Alternative with the filename fixed in the graph, so no feed is required:
# s = tf.constant('file0.csv', tf.string)
ss = ["file0.csv", "file1.csv"]
inputs, labels = test(s)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    coord = tf.train.Coordinator()
    # Starts the QueueRunner registered by string_input_producer. Its enqueue
    # op depends on the unfed placeholder `s`, so it fails and the filename
    # queue gets closed.
    threads = tf.train.start_queue_runners(coord=coord)

    for e in ss:
        # This feed_dict applies only to this sess.run call in the main
        # thread. The reader then finds the filename queue closed and empty,
        # hence the OutOfRangeError shown in the traceback.
        inputs_val, labels_val = sess.run([inputs, labels], feed_dict = {s: e})
        print("input {} - label {}".format(inputs_val, labels_val))

    coord.request_stop()
    coord.join(threads)

感谢您的帮助。

(tensorflow)[yuming @ atlas1 working-files] $ python 36.py

2017-10-11 11:28:40.825044: I tensorflow/core/common_runtime/gpu/gpu_device.cc:965] Found device 0 with properties:
name: Quadro M4000 major: 5 minor: 2 memoryClockRate(GHz): 0.7725
pciBusID: 0000:83:00.0
totalMemory: 7.93GiB freeMemory: 7.87GiB
2017-10-11 11:28:40.931938: I tensorflow/core/common_runtime/gpu/gpu_device.cc:965] Found device 1 with properties:
name: Quadro K2200 major: 5 minor: 0 memoryClockRate(GHz): 1.124
pciBusID: 0000:03:00.0
totalMemory: 3.95GiB freeMemory: 3.47GiB
2017-10-11 11:28:40.931990: I tensorflow/core/common_runtime/gpu/gpu_device.cc:980] Device peer to peer matrix
2017-10-11 11:28:40.931998: I tensorflow/core/common_runtime/gpu/gpu_device.cc:986] DMA: 0 1
2017-10-11 11:28:40.932002: I tensorflow/core/common_runtime/gpu/gpu_device.cc:996] 0:   Y N
2017-10-11 11:28:40.932005: I tensorflow/core/common_runtime/gpu/gpu_device.cc:996] 1:   N Y
2017-10-11 11:28:40.932013: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1055] Creating TensorFlow device (/device:GPU:0) -> (device: 0, name: Quadro M4000, pci bus id: 0000:83:00.0, compute capability: 5.2)
2017-10-11 11:28:40.932018: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1042] Ignoring gpu device (device: 1, name: Quadro K2200, pci bus id: 0000:03:00.0, compute capability: 5.0) with Cuda multiprocessor count: 5. The minimum required count is 8. You can adjust this requirement with the env var TF_MIN_GPU_MULTIPROCESSOR_COUNT.
Traceback (most recent call last):
  File "36.py", line 26, in <module>
    inputs_val, labels_val = sess.run([inputs, labels], feed_dict = {s: 'file0.csv'})
  File "/home/yuming/tensorflow/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 889, in run
    run_metadata_ptr)
  File "/home/yuming/tensorflow/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1118, in _run
    feed_dict_tensor, options, run_metadata)
  File "/home/yuming/tensorflow/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1315, in _do_run
    options, run_metadata)
  File "/home/yuming/tensorflow/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1334, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.OutOfRangeError: FIFOQueue '_0_input_producer' is closed and has insufficient elements (requested 1, current size 0)
         [[Node: ReaderReadV2 = ReaderReadV2[_device="/job:localhost/replica:0/task:0/cpu:0"](TextLineReaderV2, input_producer)]]

Caused by op u'ReaderReadV2', defined at:
  File "36.py", line 17, in <module>
    inputs, labels = test(s)
  File "36.py", line 7, in test
    key, value = reader.read(filename_queue)
  File "/home/yuming/tensorflow/lib/python2.7/site-packages/tensorflow/python/ops/io_ops.py", line 194, in read
    return gen_io_ops._reader_read_v2(self._reader_ref, queue_ref, name=name)
  File "/home/yuming/tensorflow/lib/python2.7/site-packages/tensorflow/python/ops/gen_io_ops.py", line 654, in _reader_read_v2
    queue_handle=queue_handle, name=name)
  File "/home/yuming/tensorflow/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 789, in _apply_op_helper
    op_def=op_def)
  File "/home/yuming/tensorflow/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 3052, in create_op
    op_def=op_def)
  File "/home/yuming/tensorflow/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1610, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

OutOfRangeError (see above for traceback): FIFOQueue '_0_input_producer' is closed and has insufficient elements (requested 1, current size 0)
         [[Node: ReaderReadV2 = ReaderReadV2[_device="/job:localhost/replica:0/task:0/cpu:0"](TextLineReaderV2, input_producer)]]

1 个答案:

答案 0 :(得分:0)

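我认为用 tf.train.string_input_producer 做不到这一点。原因是 string_input_producer 会注册一个 QueueRunner,它的入队操作由 tf.train.start_queue_runners 启动的后台线程执行。这些线程调用 sess.run 时不会带上你在主线程中传入的 feed_dict,因此占位符得不到值,入队失败,随后队列被关闭,主线程读取时就会报 OutOfRangeError。一种可行的替代方案是直接使用 tf.FIFOQueue(string_input_producer 在底层也是用它实现的),并手动把文件名加入队列。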

# Filename queue filled by hand from the main thread. Unlike
# tf.train.string_input_producer, no QueueRunner thread enqueues into it, so a
# placeholder fed through feed_dict does reach it.
filename_queue = tf.FIFOQueue(capacity=100, dtypes=[tf.string])
# Build the enqueue op once and feed the filename on each run; calling
# filename_queue.enqueue("...") for every file would add a new op to the
# graph each time.
filename = tf.placeholder(tf.string, shape=[], name="filename")
enqueue_filename = filename_queue.enqueue(filename)

reader = tf.TextLineReader()
# Returns the next line, dequeuing a new filename once the current file is
# exhausted.
_, value = reader.read(filename_queue)
# Five float columns per line; only the last one (the target) is printed.
_, _, _, _, target = tf.decode_csv(
    value, record_defaults=[[1.], [1.], [1.], [1.], [1.]])

with tf.Session() as session:
  coord = tf.train.Coordinator()
  # No QueueRunner is registered for filename_queue; this only matters if
  # other parts of the graph use tf.train input pipelines.
  threads = tf.train.start_queue_runners(coord=coord)

  for name in ["file0.csv", "file1.csv"]:
    session.run(enqueue_filename, feed_dict={filename: name})
    # Read exactly as many lines as the file has (10 in this example). One
    # more read would hang: the queue is empty and the last CSV is fully read.
    for _ in range(10):
      print(session.run(target))

  coord.request_stop()
  coord.join(threads)

我的两个 CSV 文件都是 10 行 5 列,所以我使用了 range(10),结果符合预期:先读取 file0.csv,再读取 file1.csv。

注意:如果读取次数超过了队列中所有文件的总行数(即没有提供足够的数据),主线程将会挂起。

我建议始终保持队列非空,并持续向其中添加文件。这样你就可以按任意顺序动态地向队列提供文件。