RandomShuffleQueue is closed when reading TFRecords

Date: 2017-07-22 23:39:07

Tags: tensorflow shuffle

I converted a CSV file ("test03.txt") into a TFRecords file ("test03.tfrecords"), but when I read the TFRecords file back and try to use tf.train.shuffle_batch, I get the error message

RandomShuffleQueue '_2_shuffle_batch_1/random_shuffle_queue' is closed and has insufficient elements (requested 10, current size 0)

The CSV file is:

1,0
2,0
3,0
4,0
5,1
6,0
7,1
8,1
9,1
10,1

I converted it to a TFRecords file with:
import pandas
import tensorflow as tf 

csv = pandas.read_csv(r"test03.txt", header=None).values
with tf.python_io.TFRecordWriter("test03.tfrecords") as writer:
   for row in csv:
      features, label = row[:-1], row[-1]
      example = tf.train.Example()
      example.features.feature["features"].float_list.value.extend(features)
      example.features.feature["label"].int64_list.value.append(label)
      writer.write(example.SerializeToString())
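Because the CSV has two columns, row[:-1] holds a single value and row[-1] holds the label, so every Example written by this loop stores exactly one float under 'features' and one int64 under 'label'. The stdlib-only sketch below mirrors that layout; encode_row is a hypothetical stand-in for the protobuf calls and simply returns (list kind, values) pairs instead of a real tf.train.Example.

```python
import csv
import io

# The CSV from the question, inlined so the sketch is self-contained.
CSV_TEXT = "1,0\n2,0\n3,0\n4,0\n5,1\n6,0\n7,1\n8,1\n9,1\n10,1\n"

def encode_row(row):
    """Hypothetical stand-in for building one tf.train.Example.

    Mirrors the writer loop above: every column except the last goes into
    a float_list named 'features', the last column into an int64_list 'label'.
    """
    values = [int(v) for v in row]
    features, label = values[:-1], values[-1]
    return {
        "features": ("float_list", [float(x) for x in features]),
        "label": ("int64_list", [label]),
    }

records = [encode_row(r) for r in csv.reader(io.StringIO(CSV_TEXT))]
print(records[0])
```

The single value per feature is also why a scalar shape [] is valid for tf.FixedLenFeature on the reading side; with more feature columns the shape would have to be [n].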

But when I run the following code, I get the error message above:

import tensorflow as tf

batch_size = 10 

with tf.Session() as sess:
   filename_queue = tf.train.string_input_producer(["test03.tfrecords"],num_epochs=1)
   reader = tf.TFRecordReader()
   _, serialized_example = reader.read(filename_queue)

   feature_dict = {'features': tf.FixedLenFeature([], tf.int64),'label': tf.FixedLenFeature([], tf.int64)}
   featuresLabel = tf.parse_single_example(serialized_example, features=feature_dict)
   xdata = tf.cast(featuresLabel['features'], tf.int32)
   label = tf.cast(featuresLabel['label'], tf.int32)

   min_after_dequeue = 1
   capacity = min_after_dequeue + 3 * batch_size
   batch_of_xs, batch_of_labels = tf.train.shuffle_batch([xdata, label], batch_size=batch_size, capacity=capacity, num_threads=1, min_after_dequeue=min_after_dequeue)

   init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
   sess.run(init_op)

   coord = tf.train.Coordinator()
   threads = tf.train.start_queue_runners(coord=coord)

   single_batch_xs, single_batch_ys = sess.run([batch_of_xs, batch_of_labels])
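The "is closed and has insufficient elements" message is also how the queue signals the normal end of input, so on its own it says little about the cause. With num_epochs=1, only num_records * num_epochs examples ever reach the queue, and once the queue is closed, shuffle_batch with its default allow_smaller_final_batch=False serves only full batches. With these ten rows and batch_size = 10, at most one sess.run on the batch tensors can succeed, and only if every record parses. A small arithmetic sketch of that rule (batches_available is a hypothetical helper, not part of TensorFlow):

```python
def batches_available(num_records, num_epochs, batch_size,
                      allow_smaller_final_batch=False):
    """Hypothetical helper: how many batch dequeues can succeed.

    Models tf.train.shuffle_batch after its input has been exhausted and
    the queue closed, assuming every record parsed successfully: full
    batches are always served; a partial final batch only when
    allow_smaller_final_batch=True.
    """
    total = num_records * num_epochs
    full, leftover = divmod(total, batch_size)
    if allow_smaller_final_batch and leftover:
        return full + 1
    return full

print(batches_available(10, 1, 10))  # the setup in the question
```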

1 Answer:

Answer 0 (score: 0)

Your problem is in feature_dict. In your original conversion to TFRecords, you write the data like this:

example.features.feature["features"].float_list.value.extend(features)
example.features.feature["label"].int64_list.value.append(label)

So your features are encoded as floats and your labels as int64. But when you read them back, you declare both as int64:

   feature_dict = {'features': tf.FixedLenFeature([], tf.int64),'label': tf.FixedLenFeature([], tf.int64)}

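Because of this mismatch, tf.parse_single_example fails inside the background queue-runner thread on the very first record, the queues are closed, and the main thread only sees the "is closed and has insufficient elements (requested 10, current size 0)" error instead of the parse error itself. The fix is simply to make feature_dict match the original encoding, so change the line above to: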

feature_dict = {'features': tf.FixedLenFeature([], tf.float32),'label': tf.FixedLenFeature([], tf.int64)}

This solved the problem for me (together with printing single_batch_xs and single_batch_ys at the end).
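The rule behind this fix is that each dtype in feature_dict must correspond to the list type used when writing: float_list maps to tf.float32, int64_list to tf.int64, and bytes_list to tf.string. As a plain-Python sketch of that check (find_dtype_mismatches and the string dtype names are hypothetical illustrations, not TensorFlow API), comparing the writer's kinds with the question's spec and with the corrected one:

```python
# The dtype tf.FixedLenFeature must be given for each list kind that a
# tf.train.Feature can hold.
KIND_TO_DTYPE = {
    "float_list": "tf.float32",
    "int64_list": "tf.int64",
    "bytes_list": "tf.string",
}

def find_dtype_mismatches(written_kinds, parse_spec):
    """Hypothetical helper: compare the writer's list kinds with a parse spec.

    written_kinds: feature name -> list kind used when writing the Example.
    parse_spec:    feature name -> dtype name passed to tf.FixedLenFeature.
    Returns a list of (name, expected_dtype, declared_dtype) tuples.
    """
    mismatches = []
    for name, declared in sorted(parse_spec.items()):
        expected = KIND_TO_DTYPE[written_kinds[name]]
        if declared != expected:
            mismatches.append((name, expected, declared))
    return mismatches

written = {"features": "float_list", "label": "int64_list"}
question_spec = {"features": "tf.int64", "label": "tf.int64"}
answer_spec = {"features": "tf.float32", "label": "tf.int64"}

print(find_dtype_mismatches(written, question_spec))
print(find_dtype_mismatches(written, answer_spec))
```

With a real file, the written kinds can be read off a record instead of typed by hand, for example by iterating with tf.python_io.tf_record_iterator, parsing with tf.train.Example.FromString, and calling WhichOneof('kind') on each entry of example.features.feature.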