使用numpy_input_fn时挂起

时间:2018-03-14 19:09:14

标签: python tensorflow

我试图了解如何使用输入函数和估算器。当我尝试使用numpy输入函数时,tensorflow会挂起。我在这里做错了吗?我在MacOS 10.13上使用Python 3.6.4运行TensorFlow 1.6.0。

以下是挂起的代码示例:

import tensorflow as tf
import numpy as np

a = np.array([1,2,3,4,5])

infun = tf.estimator.inputs.numpy_input_fn(
      x={"x": a},
      batch_size=2,
      num_epochs=3,
      shuffle=False)

batch = infun()

with tf.Session() as sess:
    print(sess.run(batch))

任何帮助将不胜感激。谢谢!

1 个答案:

答案 0 :(得分:0)

tf.estimator.inputs.numpy_input_fn方法挂起,因为它需要队列运行器来提供输入。队列运行程序通常由tf.estimator方法处理,因此要在您必须设置并启动队列的会话中自行运行它。

较新的tf.data.Dataset是一种更简单的解决方案,因为您可以直接在tf.Session中使用它,并且它还与tf.estimator方法兼容:

import tensorflow as tf
import numpy as np

a = np.array([1,2,3,4,5])

dataset = tf.data.Dataset.from_tensor_slices({"x": a})
dataset = dataset.repeat(3) # number of epochs
dataset = dataset.batch(2) # batch size

batch = dataset.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    print(sess.run(batch))