如何在Tensorflow中读取一个文件?

时间:2017-03-23 21:13:49

标签: python tensorflow queue

Tensorflow中有读取文件的函数,但这些函数接受文件名队列。

这意味着,当我从文件本身读取文件时,我有义务从外部推断出标签。

不幸的是,我在内存中有一个元组列表,其中每个元组由文件名和标签组成。即标签不在文件中,而是在内存中。

是否有可能以某种方式创建两个同步队列或以其他方式从不同来源获取数据和标签?

更新

我写了这样的东西,但失败了

data = [[os.path.join(corpus_dir,filename),label] for(filename,label)in data]

def read_my_file():

    records = tf.train.input_producer(data)
    record = records.dequeue()
    filename = record[0]
    filenames = tf.FIFOQueue(1, tf.string)
    filenames.enqueue(filename)
    label = record[1]
    reader = tf.WholeFileReader()
    key, raw = reader.read(filenames)
    image = tf.image.decode_png(raw)
    return image, label


image, label = read_my_file()

init_op = tf.initialize_all_variables()
with tf.Session() as sess:
    sess.run(init_op)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    for i in range(10):
        image1, label1 = sess.run(image, label)
        print(label1)

此处data是Python内存中的元组列表,而filenames是我组织为文件阅读器提供的队列。

看起来很糟糕但不起作用:

...test05.py", line 37, in <module>
    image1, label1 = sess.run(image, label)
  File "C:\Python35\lib\site-packages\tensorflow\python\client\session.py", line 769, in run
    run_metadata_ptr)
  File "C:\Python35\lib\site-packages\tensorflow\python\client\session.py", line 915, in _run
    if feed_dict:
  File "C:\Python35\lib\site-packages\tensorflow\python\framework\ops.py", line 525, in __bool__
    raise TypeError("Using a `tf.Tensor` as a Python `bool` is not allowed. "
TypeError: Using a `tf.Tensor` as a Python `bool` is not allowed. Use `if t is not None:` instead of `if t:` to test if a tensor is defined, and use TensorFlow ops such as tf.cond to execute subgraphs conditioned on the value of a tensor.

如你所见,我无处使用条件。

1 个答案:

答案 0 :(得分:4)

由于您使用的是tf.WholeFileReader,因此可以通过将其替换为更简单的tf.read_file()操作来避免同步多个队列的问题,如下所示:

def read_my_file():
    records = tf.train.input_producer(data)
    filename, label = records.dequeue()
    raw = tf.read_file(filename)
    image = tf.image.decode_png(raw)
    return image, label