我试图了解如何使用输入函数和估算器。当我尝试使用numpy输入函数时,tensorflow会挂起。我在这里做错了吗?我在MacOS 10.13上使用Python 3.6.4运行TensorFlow 1.6.0。
以下是挂起的代码示例:
import tensorflow as tf
import numpy as np
a = np.array([1,2,3,4,5])
infun = tf.estimator.inputs.numpy_input_fn(
x={"x": a},
batch_size=2,
num_epochs=3,
shuffle=False)
batch = infun()
with tf.Session() as sess:
print(sess.run(batch))
任何帮助将不胜感激。谢谢!
答案 0 :(得分:0)
tf.estimator.inputs.numpy_input_fn
方法挂起,因为它需要队列运行器来提供输入。队列运行程序通常由tf.estimator
方法处理,因此要在您必须设置并启动队列的会话中自行运行它。
较新的tf.data.Dataset
是一种更简单的解决方案,因为您可以直接在tf.Session
中使用它,并且它还与tf.estimator
方法兼容:
import tensorflow as tf
import numpy as np
a = np.array([1,2,3,4,5])
dataset = tf.data.Dataset.from_tensor_slices({"x": a})
dataset = dataset.repeat(3) # number of epochs
dataset = dataset.batch(2) # batch size
batch = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
print(sess.run(batch))