我跟随tutorial在tensorflow.org上生成输入函数。
一切正常,直到我尝试打印预测(它应该只有6个预测)。
y = regressor.predict(input_fn=lambda: input_fn(prediction_set))
print ("Predictions: {}".format(str(y)))
我得到了这个输出:<generator object _as_iterable at 0x7fa66ec6cfa0>
。
如果我尝试将生成器转换为list(y)
的列表。程序冻结了。
如果我尝试只获得前30个项目(即使应该只有6个):
import itertools
print(list(itertools.islice(y, 30)))
我得到以下内容:
[34.382435, 20.170452, 23.214834, 37.211243, 17.090082, 19.648254,
34.382435, 20.170452, 23.214834, 37.211243, 17.090082, 19.648254,
34.382435, 20.170452, 23.214834, 37.211243, 17.090082, 19.648254,
34.382435, 20.170452, 23.214834, 37.211243, 17.090082, 19.648254,
34.382435, 20.170452, 23.214834, 37.211243, 17.090082, 19.648254]
正如您所看到的,无限期地重复这些价值。
我错过了什么吗?
Tensorflow版本:0.12.0-rc1
Python版本:2.7.6
答案 0 :(得分:0)
我们最近使用tf.contrib.learn修改了#34;构建输入函数&#34;教程是最新的Estimators中的predict()的最新默认行为,它返回一个生成器。修订后的教程文本在这里:
https://www.tensorflow.org/versions/master/tutorials/input_fn/
最新的代码在这里:
以下是使用itertools.islice的相关部分,与上次尝试时一样:
y = regressor.predict(input_fn=lambda: input_fn(prediction_set))
# .predict() returns an iterator; convert to a list and print predictions
predictions = list(itertools.islice(y, 6))
print("Predictions: {}".format(str(predictions)))
当islice的第二个参数增加到30时,我没有看到预测值重复的行为。您是否可以尝试从上面的链接中提取GitHub中的最新代码,如果您仍然遇到该问题,请告诉我?