TensorFlow freezes at session run when reading data from TFRecords

Time: 2017-08-17 10:35:55

Tags: tensorflow

Here is the code:

import tensorflow as tf
import sys
from tensorflow.python.platform import gfile
import numpy as np
from scipy.misc import imread
import glob


with open("./labels_510.txt") as f:
    lines = list(f.readlines())
    labels = [str(w).replace("\n", "") for w in lines]

NCLASS = len(labels)
NCHANNEL = 3
WIDTH = 224
HEIGHT = 224

def getImageBatch(filenames, batch_size, capacity, min_after_dequeue):
    filenameQ = tf.train.string_input_producer(filenames, num_epochs=None)
    recordReader = tf.TFRecordReader()
    key, fullExample = recordReader.read(filenameQ)
    key_val = sess.run(key)
    print(key_val)
    features = tf.parse_single_example(
        fullExample,
        features={
            'image/height': tf.FixedLenFeature([], tf.int64),
            'image/width': tf.FixedLenFeature([], tf.int64),
            'image/colorspace': tf.FixedLenFeature([], dtype=tf.string, default_value=''),
            'image/channels': tf.FixedLenFeature([], tf.int64),
            'image/class/label': tf.FixedLenFeature([], tf.int64),
            'image/class/text': tf.FixedLenFeature([], dtype=tf.string, default_value=''),
            'image/format': tf.FixedLenFeature([], dtype=tf.string, default_value=''),
            'image/filename': tf.FixedLenFeature([], dtype=tf.string, default_value=''),
            'image/encoded': tf.FixedLenFeature([], dtype=tf.string, default_value='')
        })
    label = features['image/class/label']
    image_buffer = features['image/encoded']
    with tf.name_scope('decode_jpeg', [image_buffer], None):
        image = tf.image.decode_jpeg(image_buffer, channels=NCHANNEL)
        image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    image = tf.reshape(1 - tf.image.rgb_to_grayscale(image), [WIDTH * HEIGHT * NCHANNEL])
    label = tf.stack(tf.one_hot(label - 1, NCLASS))
    imageBatch, labelBatch = tf.train.shuffle_batch(
        [image, label], batch_size=batch_size,
        capacity=capacity,
        min_after_dequeue=min_after_dequeue)
    print(imageBatch.shape)
    print(labelBatch.shape)
    return imageBatch, labelBatch



with gfile.FastGFile("./output_graph_510.pb", 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    with tf.Session() as sess:
        sess.graph.as_default()
        tf.import_graph_def(graph_def)
        tf.global_variables_initializer().run()
        image_tensor, label_batch = getImageBatch(glob.glob("./images/tf_records/validation*"), 1, 10, 2)
        image_tensor = tf.reshape(image_tensor, (1, WIDTH, HEIGHT, NCHANNEL))
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        image_data = sess.run(image_tensor)
        # print(image_data.shape)
        # softmax_tensor = sess.graph.get_tensor_by_name('import/final_result:0')
        # predictions = sess.run(softmax_tensor, {'import/input:0': image_data})
        # predictions = np.squeeze(predictions)
        # print(predictions)
        coord.request_stop()
        coord.join(threads)

When I run it, it freezes after printing the following messages:

2017-08-17 12:33:10.235086: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE4.1 instructions, but these are available on your machine and could speed up CPU computations.
2017-08-17 12:33:10.235099: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE4.2 instructions, but these are available on your machine and could speed up CPU computations.
2017-08-17 12:33:10.235101: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX instructions, but these are available on your machine and could speed up CPU computations.
2017-08-17 12:33:10.235104: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX2 instructions, but these are available on your machine and could speed up CPU computations.
2017-08-17 12:33:10.235106: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use FMA instructions, but these are available on your machine and could speed up CPU computations.
2017-08-17 12:33:10.322321: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:893] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2017-08-17 12:33:10.322510: I tensorflow/core/common_runtime/gpu/gpu_device.cc:940] Found device 0 with properties: 
name: GeForce GTX 1050
major: 6 minor: 1 memoryClockRate (GHz) 1.493
pciBusID 0000:01:00.0
Total memory: 3.95GiB
Free memory: 2.23GiB
2017-08-17 12:33:10.322519: I tensorflow/core/common_runtime/gpu/gpu_device.cc:961] DMA: 0 
2017-08-17 12:33:10.322522: I tensorflow/core/common_runtime/gpu/gpu_device.cc:971] 0:   Y 
2017-08-17 12:33:10.322529: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1030] Creating TensorFlow device (/gpu:0) -> (device: 0, name: GeForce GTX 1050, pci bus id: 0000:01:00.0)
  • TensorFlow version: 1.2.1
  • Ubuntu 16.04
  • GeForce GTX 1050

The full project can be found here: https://github.com/kindlychung/demo-load-pb-tensorflow

1 answer:

Answer 0 (score: 2)

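It freezes because you did not initialize the local variables associated with the queues used by tf.train.shuffle_batch. Local variables are usually temporary variables created for operations such as enqueue and dequeue to keep track of elements. Initialize them together with the global variables before starting the queue runners: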

...
sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
image_data = sess.run(image_tensor)
print(image_data.shape)
...
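Why does this show up as a silent freeze instead of an error? A dequeue from an empty queue simply waits for data. In TensorFlow 1.x, sess.run on a tensor fed by a queue blocks in the same way until the threads started by tf.train.start_queue_runners fill that queue. The sketch below is a pure-Python analogy, not TensorFlow code: `producer`, `consume` and `run` are hypothetical names, and a timeout is used so the demo returns instead of hanging. By the same reasoning, the question's `sess.run(key)` inside `getImageBatch` runs before `tf.train.start_queue_runners` is called, so it is likely to block as well.

```python
import queue
import threading


def producer(q, items):
    # Analogy for a queue runner: a background thread that fills the queue.
    for item in items:
        q.put(item)


def consume(q, timeout):
    # Analogy for sess.run on a dequeue op. A real dequeue would wait forever;
    # here we give up after `timeout` seconds and return None instead.
    try:
        return q.get(timeout=timeout)
    except queue.Empty:
        return None


def run(start_runner):
    q = queue.Queue(maxsize=10)  # loosely analogous to the `capacity` argument
    if start_runner:
        t = threading.Thread(
            target=producer, args=(q, [b"record-0", b"record-1"]), daemon=True
        )
        t.start()
    return consume(q, timeout=0.5)


if __name__ == "__main__":
    print(run(start_runner=False))  # None: nothing ever fills the queue
    print(run(start_runner=True))   # b'record-0'
```

The takeaway matches the answer's ordering: build the whole input pipeline first, run the initializers, start the queue runners, and only then call sess.run on anything that reads from a queue.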