TensorFlow hangs when calling sess.run()

Posted: 2017-07-31 20:28:33

Tags: python session tensorflow

The following code hangs (only CTRL-Z gets me out):

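from __future__ import print_function

import tensorflow as tf
import cifar10  # from https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10 (both cifar10.py & cifar10_input.py)


def main():
    print('TensorFlow version: ', tf.__version__)

    with tf.Session() as sess:
        # Build the CIFAR-10 distorted input pipeline on the CPU.
        with tf.device('/cpu:0'):
            images, labels = cifar10.distorted_inputs()

        # A plain constant tensor: evaluating it does not read from any queue.
        input = tf.constant([[[1, 2, 3], [5, 5, 5]], [[4, 5, 6], [7, 7, 7]], [[7, 8, 9], [9, 9, 9]]])

        one = input[0]
        print("X1 ", type(input), one)
        oneval = sess.run(one)  # returns immediately
        print("X2 ", type(one), one, type(oneval), oneval)

        # The first image of the distorted batch: this depends on the input pipeline.
        two = images[0]
        print("Y1 ", type(images), two)
        twoval = sess.run(two)  # hangs here
        print("Y2 ", type(two), two, type(twoval), twoval)


if __name__ == '__main__':
    main()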

I get the following output (using Python 2.7.5):

[gpu@centos-7-4 demo]$ python demo.py
TensorFlow version:  1.2.1
2017-07-31 16:06:45.503157: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE4.1 instructions, but these are available on your machine and could speed up CPU computations.
2017-07-31 16:06:45.503182: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE4.2 instructions, but these are available on your machine and could speed up CPU computations.
2017-07-31 16:06:45.503187: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX instructions, but these are available on your machine and could speed up CPU computations.
Filling queue with 20000 CIFAR images before starting to train. This will take a few minutes.
X1  <class 'tensorflow.python.framework.ops.Tensor'> Tensor("strided_slice:0", shape=(2, 3), dtype=int32)
X2  <class 'tensorflow.python.framework.ops.Tensor'> Tensor("strided_slice:0", shape=(2, 3), dtype=int32) <type 'numpy.ndarray'> [[1 2 3] [5 5 5]]
Y1  <class 'tensorflow.python.framework.ops.Tensor'> Tensor("strided_slice_1:0", shape=(24, 24, 3), dtype=float32)
^Z

Does anyone have any suggestions (or a solution)?

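In case anyone is interested in the background: my end goal is to convert the tensors returned by distorted_inputs() into a set of JSON objects. So the naive plan is to iterate over each element of images and extract its values.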
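Once the pipeline actually runs, sess.run() returns NumPy arrays rather than Tensors, so the JSON step does not need to walk tensor elements one by one. A minimal sketch, assuming `images_val` and `labels_val` are the arrays returned by something like `sess.run([images, labels])`; the function name `batch_to_json` and the record layout `{"label": ..., "image": ...}` are hypothetical choices for illustration. It calls `ndarray.tolist()` and `int()` because the `json` module cannot serialize NumPy arrays or NumPy scalars directly.

```python
import json

import numpy as np


def batch_to_json(images_val, labels_val):
    """Convert one evaluated batch into a list of JSON strings (hypothetical layout).

    images_val: float array of shape (batch, height, width, channels)
    labels_val: int array of shape (batch,)
    """
    records = []
    for image, label in zip(images_val, labels_val):
        # tolist() turns a NumPy array into nested Python lists of floats,
        # and int() turns a NumPy integer scalar into a Python int,
        # both of which the json module can serialize.
        records.append(json.dumps({"label": int(label), "image": image.tolist()}))
    return records


if __name__ == "__main__":
    # Stand-in data with the shapes seen above (24x24x3 images),
    # used only because no TensorFlow session is available here.
    fake_images = np.zeros((2, 24, 24, 3), dtype=np.float32)
    fake_labels = np.array([3, 7], dtype=np.int32)
    out = batch_to_json(fake_images, fake_labels)
    print(len(out), json.loads(out[0])["label"])  # 2 3
```

Emitting one JSON string per image keeps each record independent; collecting them into a single list and dumping that once would work just as well.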

2 answers:

Answer 0 (score: 6)

I found the answer here: Printing tensorflow tensor in Python hangs forever

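The key is these two lines. cifar10.distorted_inputs() builds a queue-based input pipeline, and the queue-runner threads that fill those queues are not started automatically, so sess.run() on anything that depends on images blocks forever waiting for data. (The constant tensor one works because it does not read from a queue.) Run them inside the session, before the first sess.run() that touches the pipeline, and call coord.request_stop() followed by coord.join(threads) when you are done: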

# Start the threads that fill the input queues, supervised by a Coordinator.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
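Why those two lines matter can be illustrated without TensorFlow. The following is a rough standard-library analogue, not TensorFlow code: a bounded `queue.Queue` stands in for the example queue, a producer thread stands in for a queue runner, and a `threading.Event` plays the role of the Coordinator's stop flag. A blocking `get()` behaves like sess.run() on a dequeue op: it waits until something is enqueued, so if no producer thread was ever started it waits forever (a timeout is used here only so the demo terminates). Written for Python 3; `producer` and `run_pipeline` are hypothetical names.

```python
import queue
import threading


def producer(q, stop_event, start):
    """Stand-in for a queue runner: keep enqueueing until asked to stop."""
    item = start
    while not stop_event.is_set():
        try:
            # Short timeout so the thread re-checks the stop flag
            # instead of blocking forever on a full queue.
            q.put(item, timeout=0.1)
            item += 1
        except queue.Full:
            continue


def run_pipeline(start_threads, n_items=3, timeout=0.5):
    """Dequeue n_items; return None if the queue stays empty (the 'hang')."""
    q = queue.Queue(maxsize=10)
    stop_event = threading.Event()  # plays the role of tf.train.Coordinator
    threads = []
    if start_threads:
        # Analogue of tf.train.start_queue_runners(sess=sess, coord=coord).
        threads = [threading.Thread(target=producer, args=(q, stop_event, 0))]
        for t in threads:
            t.start()
    try:
        # Analogue of sess.run(dequeue_op): blocks until an item arrives.
        return [q.get(timeout=timeout) for _ in range(n_items)]
    except queue.Empty:
        # Without a producer, a get() with no timeout would never return.
        return None
    finally:
        # Analogue of coord.request_stop() followed by coord.join(threads).
        stop_event.set()
        for t in threads:
            t.join()


if __name__ == "__main__":
    print(run_pipeline(start_threads=True))   # [0, 1, 2]
    print(run_pipeline(start_threads=False))  # None
```

The shutdown in the `finally` block mirrors why the Coordinator exists: the producer threads must be told to stop and then joined, otherwise they keep running after the consumer is finished.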

Answer 1 (score: 0)

I ran into this problem with TF 1.7; after downgrading to 1.3 it worked fine.