TensorFlow:如何在string_input_producer中使用'num_epochs'

时间:2017-09-06 09:47:52

标签: python tensorflow

我无法在我的string_input_producer上启用纪元限制而不会收到OutOfRange错误(请求x,当前大小为0)。我要求的元素数量似乎并不重要,总有0个可用。

这是我的FileQueue构建器:

def get_queue(base_directory):
    files = [f for f in os.listdir(base_directory) if f.endswith('.bin')]
    shuffle(files)
    file = [os.path.join(base_directory, files[0])]
    fileQueue = tf.train.string_input_producer(file, shuffle=False, num_epochs=1)

    return fileQueue

如果我从string_input_producer中删除 num_epochs = 1 ,它可以创建样本。

我的输入管道:

def input_pipeline(instructions, fileQueue):
    example, label, feature_name_list = read_binary_format(fileQueue, instructions)

    num_preprocess_threads = 16
    capacity = 20

    example, label = tf.train.batch(
        [example, label],
        batch_size=50000,    # set the batch size way bigger so we always return the full amount of samples from the file
        allow_smaller_final_batch=True,
        capacity=capacity,
        num_threads=num_preprocess_threads)

    return example, label

最后我的会议:

with tf.Session(graph=tf.Graph()) as sess:
    train_inst_set = sf.DeserializationInstructions.from_filename(os.path.join(input_dir, "Train/config.json"))
    fileQueue = sf.get_queue(os.path.join(input_dir, "Train"))
    features_train, labels_train = sf.input_pipeline(train_inst_set, fileQueue)
    sess.run(tf.global_variables_initializer())

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)

    train_feature_batch, train_label_batch = sess.run([features_train, labels_train])

2 个答案:

答案 0 :(得分:1)

问题原因是:Issue #1045

无论出于何种原因,tf.global_variable_initialiser都不会初始化所有变量。您还需要初始化局部变量。

添加

sess.run(tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()))

到您的会话。

答案 1 :(得分:0)

来自TensorFlow Documentation

  

num_epochs:一个整数(可选)。如果指定,则为string_input_producer   之前从string_tensor num_epochs生成每个字符串   生成OutOfRange错误。如果没有指定,   string_input_producer可以循环遍历string_tensor中的字符串   无限次数

因此,一旦数据消耗 num_epochs 次,string_input_producer将启动 OutOfRange 例外

为避免突然停止,您有多种解决方案:

  • 使用更高级别的API tf.train.MonitoredSession,自动处理OutOfRange错误。

  • 将您的训练说明放在 try except 块中。捕获异常时,请协调员关闭线程。见example