使用tensorflow的shuffle_batch方法

时间:2016-07-15 15:52:13

标签: python tensorflow

我正在尝试tensorflow,我正在尝试从csv文件中读取并通过shuffle_batch打印出一批数据。我已经抛弃了decode_csv docsshuffle_batch docs,但我仍然无法让它发挥作用。

这就是我所拥有的: 导入tensorflow为tf

sess = tf.InteractiveSession()

filename_queue = tf.train.string_input_producer(
    ["./data/train.csv"], num_epochs=1, shuffle=True) # total record count in csv is 30K
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)

record_defaults = [["1"], ["2"]] # irrelevant for this discussion
input, outcome = tf.decode_csv(value, record_defaults=record_defaults)

min_after_dequeue = 1000
batch_size = 10
capacity = min_after_dequeue + 3 * batch_size

example_batch = tf.train.shuffle_batch([outcome], batch_size, capacity, min_after_dequeue)
coord = tf.train.Coordinator()
tf.train.start_queue_runners(sess, coord=coord)
example_batch.eval(session = sess)

运行此命令将生成此异常:

OutOfRangeError: RandomShuffleQueue
    '_3_shuffle_batch_1/random_shuffle_queue' is closed 
    and has insufficient elements (requested 10, current size 0)

我不确定问题是什么。我有一种感觉,这是由于会议和我处理它的方式;我可能没有正确地做到这一点。

2 个答案:

答案 0 :(得分:0)

尝试从num_epochs=1初始化程序中删除string_input_producer

答案 1 :(得分:-2)

"注意:如果num_epochs不是None,则此函数创建本地计数器纪元。使用local_variables_initializer()初始化局部变量。"见:https://www.tensorflow.org/api_docs/python/tf/train/string_input_producer