我正在尝试使用tensorflow读取cifar10输入数据。我知道在tensorflow / models github repo下存在cifar10.py。运行以下代码时,它基本上卡住了,并且在屏幕上显示没有结果。帮助表示感谢。
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
import numpy as np
import tensorflow as tf
from tensorflow.contrib import learn
from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib
tf.logging.set_verbosity(tf.logging.INFO)
def read_input(filename_queue):
label_bytes = 1
record_bytes = label_bytes + (32 * 32 * 3)
reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
key, value = reader.read(filename_queue)
single_row_bytes = tf.decode_raw(value, tf.uint8)
label = tf.cast(tf.strided_slice(single_row_bytes, [0], [1]), tf.int32)
image = tf.transpose(tf.reshape(tf.strided_slice(single_row_bytes, [1], [3073]), [3, 32, 32]), [1,2,0])
return label, image
def generate_image_and_label_batch(label, image, batch_size, num_preprocess_threads, min_queue_examples):
images, label_batch = tf.train.batch(
[image, label],
batch_size=batch_size,
num_threads=num_preprocess_threads,
capacity=min_queue_examples + 3 * batch_size)
return images, tf.reshape(label_batch, [batch_size])
def mainFunction(file_names):
filename_queue = tf.train.string_input_producer(file_names)
label, image = read_input(filename_queue)
image.set_shape([32, 32, 3])
label.set_shape([1])
return generate_image_and_label_batch(label, image, 10, 2, 4000)
def main(unused_argv):
with tf.Graph().as_default():
with tf.Session() as sess:
x, y = sess.run(mainFunction(["cifar-10-batches-py/data_batch_" + str(i) for i in range(1,2)]))
print(x)
print(y)
if __name__ == "__main__":
tf.app.run()