TensorFlow tf.train.string_input_producer and shuffle_batch problem

Time: 2018-01-04 07:25:44

Tags: python tensorflow

The code works fine without num_epochs,

but when I add num_epochs I get the following error:

OutOfRangeError (see above for traceback): RandomShuffleQueue '_1_shuffle_batch/random_shuffle_queue' is closed and has insufficient elements (requested 2, current size 0)
[[Node: shuffle_batch = QueueDequeueManyV2[component_types=[DT_FLOAT, DT_INT32], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](shuffle_batch/random_shuffle_queue, shuffle_batch/n)]]

I have been following the official TensorFlow tutorial and can't get num_epochs to work.

What I want is for an error to be raised once the epoch limit (num_epochs) is passed, so that I don't have to keep track of the current batch_num and max_batch_num by counting the number of examples in my entire training file.

Any idea why? I think I'm doing something wrong.

""" Some people tried to use TextLineReader for the assignment 1
but seem to have problems getting it to work, so here is a short 
script demonstrating the use of CSV reader on the heart dataset.
Note that the heart dataset is originally in txt so I first
converted it to csv to take advantage of the already laid out columns.
You can download heart.csv in the data folder.
Author: Chip Huyen
Prepared for the class CS 20SI: "TensorFlow for Deep Learning Research"
cs20si.stanford.edu
"""
import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'

import sys
sys.path.append('..')

import tensorflow as tf

DATA_PATH = './heart.csv'
BATCH_SIZE = 2
N_FEATURES = 9

def batch_generator(filenames):
    """ filenames is the list of files you want to read from. 
    In this case, it contains only heart.csv
    """
    filename_queue = tf.train.string_input_producer(filenames, num_epochs=3)
    reader = tf.TextLineReader(skip_header_lines=1) # skip the first line in the file
    _, value = reader.read(filename_queue)

    # record_defaults are the default values in case some of our columns are empty
    # This is also to tell tensorflow the format of our data (the type of the decode result)
    # for this dataset, out of 9 feature columns, 
    # 8 of them are floats (some are integers, but to make our features homogeneous, 
    # we consider them floats), and 1 is a string (at position 5)
    # the last column corresponds to the label and is an integer

    record_defaults = [[1.0] for _ in range(N_FEATURES)]
    record_defaults[4] = ['']
    record_defaults.append([1])

    # decode one CSV line into its 10 columns
    content = tf.decode_csv(value, record_defaults=record_defaults) 

    # convert the 5th column (Present/Absent) to a binary value: 1.0 for 'Present', 0.0 otherwise
    content[4] = tf.cond(tf.equal(content[4], tf.constant('Present')), lambda: tf.constant(1.0), lambda: tf.constant(0.0))

    # pack all 9 features into a tensor
    features = tf.stack(content[:N_FEATURES])

    # assign the last column to label
    label = content[-1]

    # minimum number of elements in the queue after a dequeue, used to ensure 
    # that the samples are sufficiently mixed
    # I think 10 times the BATCH_SIZE is sufficient
    min_after_dequeue = 10 * BATCH_SIZE

    # the maximum number of elements in the queue
    capacity = 20 * BATCH_SIZE

    # shuffle the data to generate BATCH_SIZE sample pairs
    data_batch, label_batch = tf.train.shuffle_batch([features, label], batch_size=BATCH_SIZE, 
                                        capacity=capacity, min_after_dequeue=min_after_dequeue)

    return data_batch, label_batch

def generate_batches(data_batch, label_batch):
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        for _ in range(400): # generate 400 batches
            features, labels = sess.run([data_batch, label_batch])
            print(features)
        coord.request_stop()
        coord.join(threads)

def main():
    data_batch, label_batch = batch_generator([DATA_PATH])
    generate_batches(data_batch, label_batch)

if __name__ == '__main__':
    main()

1 Answer:

Answer 0 (score: 1)

tf.train.string_input_producer() uses "local variables" in its implementation, so you need to add

sess.run(tf.local_variables_initializer()) 

...before you start the queue runners.
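
For example, the generate_batches() function from the question could be adapted as follows (a sketch; only the initialization and the stopping logic change, and the OutOfRangeError raised after num_epochs passes is caught, which is exactly the stopping signal asked for above):

def generate_batches(data_batch, label_batch):
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # num_epochs is tracked in a local variable, so local variables
        # must be initialized as well before the queue runners start
        sess.run(tf.local_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        try:
            while not coord.should_stop():
                features, labels = sess.run([data_batch, label_batch])
                print(features)
        except tf.errors.OutOfRangeError:
            # raised once the producer has delivered num_epochs epochs
            pass
        finally:
            coord.request_stop()
            coord.join(threads)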

For usability reasons, we now encourage TensorFlow users to use the tf.data API to build input pipelines. Your code could be rewritten as follows:

# Start with a dataset of filenames.
dataset = tf.data.Dataset.from_tensor_slices(filenames)

# Repeat the filenames for three epochs.
dataset = dataset.repeat(3)

# Use Dataset.flat_map() and tf.data.TextLineDataset to convert the
# filenames into a dataset of lines.
dataset = dataset.flat_map(
    lambda filename: tf.data.TextLineDataset(filename).skip(1))

# Wrap the per-line parsing logic in a function, and map it over the dataset.
def parse_line(value):
    record_defaults = [[1.0] for _ in range(N_FEATURES)]
    record_defaults[4] = ['']
    record_defaults.append([1])

    # decode one CSV line into its 10 columns
    content = tf.decode_csv(value, record_defaults=record_defaults) 

    # convert the 5th column (Present/Absent) to a binary value: 1.0 for 'Present', 0.0 otherwise
    content[4] = tf.cond(tf.equal(content[4], tf.constant('Present')), lambda: tf.constant(1.0), lambda: tf.constant(0.0))

    # pack all 9 features into a tensor
    features = tf.stack(content[:N_FEATURES])

    # assign the last column to label
    label = content[-1]

    return features, label

dataset = dataset.map(parse_line)

# Shuffle the dataset.
dataset = dataset.shuffle(20 * BATCH_SIZE)

# Combine consecutive elements into batches.
dataset = dataset.batch(BATCH_SIZE)

# Create an iterator to get elements from the dataset.
iterator = dataset.make_one_shot_iterator()

# Get tensors that represent the next element of the iterator.
data_batch, label_batch = iterator.get_next()

# Finally, create a session to iterate over the batches.
with tf.Session() as sess:
    try:
        while True:
            features, labels = sess.run([data_batch, label_batch])
            print(features)
    except tf.errors.OutOfRangeError:
        # Raised when there are no more batches to produce.
        pass
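
With Dataset.repeat(3), iterator.get_next() raises tf.errors.OutOfRangeError once three passes over the file have been produced, so the loop stops on its own without any manual tracking of batch_num or max_batch_num. The shuffle(20 * BATCH_SIZE) buffer plays roughly the same role as the capacity/min_after_dequeue pair in the queue-based version.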