Tensorflow:如何在文件结束前读取记录

时间:2017-09-06 08:24:38

标签: python tensorflow

我有一个用于训练RNN的数据集,其中样本序列包含在各个文件中。

timeSeries1.bin
timeSeries2.bin

在每个 timeseries.bin 中有不同数量的样本。当文件名存在于 tf.train.string_input_producer 中时,如何从一个时间序列加载所有样本?我需要能够丢弃序列之间的RNN状态,这意味着我需要知道何时到达序列的末尾。

这是我的输入管道功能:<​​/ p>

def input_pipeline(instructions, base_directory):

    files = [f for f in os.listdir(base_directory) if f.endswith('.bin')]

    filename_list = [os.path.join(base_directory, x) for x in files]
    filename_queue = tf.train.string_input_producer(
        filename_list, shuffle=True, capacity=100)
    example, label, feature_name_list = read_binary_format(filename_queue, instructions)

    num_preprocess_threads = 16
    capacity = 20

    example, label = tf.train.batch(
        [example, label],
        batch_size=700,
        capacity=capacity,
        num_threads=num_preprocess_threads)

    return example, label

我知道我需要根据当前 filename_queue 中文件的大小来改变我的批量大小,但我不知道如何做到这一点。

0 个答案:

没有答案