TensoFlow:当train.string_input_producer()中指定num_epochs时,tf.train.shuffle_batch()总是抛出OutOfRange

时间:2017-07-20 03:52:44

标签: tensorflow

我编写了一个简单的程序来尝试在TensorFlow中的批处理函数中读取数据,但遇到了一个问题:

我创建了6个简单的csv文件;每个文件包含3条记录,如:

1.0,1.0,1.0,1.0,1
1.1,1.1,1.1,1.1,1
1.2,1.2,1.2,1.2,1

(前4列是功能,第5列是标签。) 所以共有6个文件有6 * 3 = 18个记录。

我尝试使用readerbatchshuffle_batch将这些文件分成3批6条记录/批次。如果我在num_epochs中未指定string_input_producer,则代码可以正常运行。但是,当我指定num_epochs时,batchshuffle_batch始终会引发OutOfRange errorcurrent_size始终为零......

以下是代码:

import tensorflow as tf
import os

csvFiles = os.listdir('./data')
csvFiles = [i for i in csvFiles if i[-4:]=='.csv' ]
csvFiles = ['./data/'+i for i in csvFiles]

print(csvFiles)

fileQ = tf.train.string_input_producer(csvFiles,shuffle=False,num_epochs=3)
reader = tf.TextLineReader()
key,value = reader.read(fileQ)
record_defaults = [[0.0], [0.0], [0.0], [0.0], [0]]
col1, col2, col3, col4, label = tf.decode_csv(value, record_defaults=record_defaults)
feature = tf.stack([col1, col2, col3, col4])
feature_batch, label_batch = tf.train.shuffle_batch([feature, label], batch_size=6, capacity=100, min_after_dequeue=1) # num_threads=3,

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess,coord=coord)

    try:
        for i in range(3):
            featureBatch, labelBatch = sess.run([feature_batch, label_batch])
            print(featureBatch)
            print(labelBatch)
    except tf.errors.OutOfRangeError:    
        print("Done reading!")
    finally:
        coord.request_stop()

coord.join(threads)
print("**END**")

OutOfRange错误的输出为here

请注意,首次调用shuffle_batch时会抛出错误。我认为这意味着不能读取单个记录。

甚至我改变了代码只读了一条记录,它抛出了同样的错误: l,f=sess.run([label,feature])

这是一个非常简单的代码。不知道它有什么问题吗?非常感谢你!

2 个答案:

答案 0 :(得分:0)

这在方法的字符串doc中解释:

  

num_epochs:一个整数(可选)。 如果已指定,string_input_producer     之前会生成string_tensor num_epochs次的每个字符串     生成OutOfRange错误。如果未指定,     string_input_producer可以循环浏览string_tensor中的字符串     无限次。

OutOfRange错误基本上重现了迭代列表时Python引发的StopIteration错误。例如,请参阅此answer

答案 1 :(得分:0)

阅读完其他示例代码后,我发现需要添加:    tf.local_variables_initializer()。运行() 初始化变量。 (即使我不知道为什么num_Epochs = 3需要初始化。)

现在代码可以工作了。