tensorflow cifar-10读取数据卡住了

时间:2017-06-07 07:44:05

标签: python machine-learning tensorflow tensor

我正在尝试使用tensorflow读取cifar10输入数据。我知道在tensorflow / models github repo下存在cifar10.py。运行以下代码时,它基本上卡住了,并且在屏幕上显示没有结果。帮助表示感谢。

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time

import numpy as np
import tensorflow as tf

from tensorflow.contrib import learn
from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib
tf.logging.set_verbosity(tf.logging.INFO)

def read_input(filename_queue):
    label_bytes = 1
    record_bytes = label_bytes  + (32 * 32 * 3)
    reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
    key, value = reader.read(filename_queue)
    single_row_bytes = tf.decode_raw(value, tf.uint8)
    label = tf.cast(tf.strided_slice(single_row_bytes, [0], [1]), tf.int32)
    image = tf.transpose(tf.reshape(tf.strided_slice(single_row_bytes, [1], [3073]), [3, 32, 32]), [1,2,0])
    return label, image

def generate_image_and_label_batch(label, image, batch_size, num_preprocess_threads, min_queue_examples):
    images, label_batch = tf.train.batch(
        [image, label],
        batch_size=batch_size,
        num_threads=num_preprocess_threads,
        capacity=min_queue_examples + 3 * batch_size)

    return images, tf.reshape(label_batch, [batch_size])

def mainFunction(file_names):
    filename_queue = tf.train.string_input_producer(file_names)
    label, image = read_input(filename_queue)
    image.set_shape([32, 32, 3])
    label.set_shape([1])
    return generate_image_and_label_batch(label, image, 10, 2, 4000)

def main(unused_argv):
    with tf.Graph().as_default():
        with tf.Session() as sess:
            x, y = sess.run(mainFunction(["cifar-10-batches-py/data_batch_" + str(i) for i in range(1,2)]))
            print(x)
            print(y)

if __name__ == "__main__":
    tf.app.run()

1 个答案:

答案 0 :(得分:0)

好像你没有启动queuerunners /初始化变量。 我对this question的回答应该有所帮助。