我已将CSV文件(“test03.txt”)转换为TFRecords格式的文件(“test03.tfrecords”),但是当我读入TFRecords文件并尝试使用tf.train.shuffle_batch时,我得到了错误消息
RandomShuffleQueue '_2_shuffle_batch_1/random_shuffle_queue' is closed and has insufficient elements (requested 10, current size 0)
CSV文件是
1,0
2,0
3,0
4,0
5,1
6,0
7,1
8,1
9,1
10,1
我使用
转换为TFRecords文件import pandas
import tensorflow as tf
csv = pandas.read_csv(r"test03.txt", header=None).values
with tf.python_io.TFRecordWriter("test03.tfrecords") as writer:
for row in csv:
features, label = row[:-1], row[-1]
example = tf.train.Example()
example.features.feature["features"].float_list.value.extend(features)
example.features.feature["label"].int64_list.value.append(label)
writer.write(example.SerializeToString())
但是当我运行以下代码时,我收到上述错误消息:
import tensorflow as tf
batch_size = 10
with tf.Session() as sess:
filename_queue = tf.train.string_input_producer(["test03.tfrecords"],num_epochs=1)
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
feature_dict = {'features': tf.FixedLenFeature([], tf.int64),'label': tf.FixedLenFeature([], tf.int64)}
featuresLabel = tf.parse_single_example(serialized_example, features=feature_dict)
xdata = tf.cast(featuresLabel['features'], tf.int32)
label = tf.cast(featuresLabel['label'], tf.int32)
min_after_dequeue = 1
capacity = min_after_dequeue + 3 * batch_size
batch_of_xs, batch_of_labels = tf.train.shuffle_batch([xdata, label], batch_size=batch_size, capacity=capacity, num_threads=1, min_after_dequeue=min_after_dequeue)
init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
sess.run(init_op)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
single_batch_xs, single_batch_ys = sess.run([batch_of_xs, batch_of_labels])
答案 0 :(得分:0)
您的问题位于feature_dict中。在您的初始示例中,您执行转换为TFRecords,如下所示:
example.features.feature["features"].float_list.value.extend(features)
example.features.feature["label"].int64_list.value.append(label)
因此,您的功能被编码为浮点数,您的标签编码为int64。但是当你把它们读回去时,你将它们变成int64:
feature_dict = {'features': tf.FixedLenFeature([], tf.int64),'label': tf.FixedLenFeature([], tf.int64)}
您的问题就像将feature_dict与初始编码匹配一样简单,因此将上面的行更改为:
feature_dict = {'features': tf.FixedLenFeature([], tf.float32),'label': tf.FixedLenFeature([], tf.int64)}
为我解决了这个问题(以及最后的single_batch_xs和ys的打印)。