我编写了一个简单的程序来尝试在TensorFlow中的批处理函数中读取数据,但遇到了一个问题:
我创建了6个简单的csv文件;每个文件包含3条记录,如:
1.0,1.0,1.0,1.0,1
1.1,1.1,1.1,1.1,1
1.2,1.2,1.2,1.2,1
(前4列是功能,第5列是标签。) 所以共有6个文件有6 * 3 = 18个记录。
我尝试使用reader
,batch
或shuffle_batch
将这些文件分成3批6条记录/批次。如果我在num_epochs
中未指定string_input_producer
,则代码可以正常运行。但是,当我指定num_epochs
时,batch
或shuffle_batch
始终会引发OutOfRange error
。 current_size
始终为零......
以下是代码:
import tensorflow as tf
import os
csvFiles = os.listdir('./data')
csvFiles = [i for i in csvFiles if i[-4:]=='.csv' ]
csvFiles = ['./data/'+i for i in csvFiles]
print(csvFiles)
fileQ = tf.train.string_input_producer(csvFiles,shuffle=False,num_epochs=3)
reader = tf.TextLineReader()
key,value = reader.read(fileQ)
record_defaults = [[0.0], [0.0], [0.0], [0.0], [0]]
col1, col2, col3, col4, label = tf.decode_csv(value, record_defaults=record_defaults)
feature = tf.stack([col1, col2, col3, col4])
feature_batch, label_batch = tf.train.shuffle_batch([feature, label], batch_size=6, capacity=100, min_after_dequeue=1) # num_threads=3,
with tf.Session() as sess:
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess,coord=coord)
try:
for i in range(3):
featureBatch, labelBatch = sess.run([feature_batch, label_batch])
print(featureBatch)
print(labelBatch)
except tf.errors.OutOfRangeError:
print("Done reading!")
finally:
coord.request_stop()
coord.join(threads)
print("**END**")
OutOfRange
错误的输出为here
请注意,首次调用shuffle_batch
时会抛出错误。我认为这意味着不能读取单个记录。
甚至我改变了代码只读了一条记录,它抛出了同样的错误:
l,f=sess.run([label,feature])
这是一个非常简单的代码。不知道它有什么问题吗?非常感谢你!
答案 0 :(得分:0)
这在方法的字符串doc中解释:
num_epochs:一个整数(可选)。 如果已指定,
string_input_producer
之前会生成string_tensor
num_epochs
次的每个字符串 生成OutOfRange
错误。如果未指定,string_input_producer
可以循环浏览string_tensor
中的字符串 无限次。
OutOfRange
错误基本上重现了迭代列表时Python引发的StopIteration
错误。例如,请参阅此answer。
答案 1 :(得分:0)
阅读完其他示例代码后,我发现需要添加: tf.local_variables_initializer()。运行() 初始化变量。 (即使我不知道为什么num_Epochs = 3需要初始化。)
现在代码可以工作了。