预测输出无限重复

时间:2016-12-21 02:49:57

标签: python tensorflow

我跟随tutorial在tensorflow.org上生成输入函数。

一切正常,直到我尝试打印预测(它应该只有6个预测)。

y = regressor.predict(input_fn=lambda: input_fn(prediction_set))
print ("Predictions: {}".format(str(y)))

我得到了这个输出:<generator object _as_iterable at 0x7fa66ec6cfa0>

如果我尝试将生成器转换为list(y)的列表。程序冻结了。

如果我尝试只获得前30个项目(即使应该只有6个):

import itertools
print(list(itertools.islice(y, 30)))

我得到以下内容:

[34.382435, 20.170452, 23.214834, 37.211243, 17.090082, 19.648254,
34.382435, 20.170452, 23.214834, 37.211243, 17.090082, 19.648254,
34.382435, 20.170452, 23.214834, 37.211243, 17.090082, 19.648254,
34.382435, 20.170452, 23.214834, 37.211243, 17.090082, 19.648254,
34.382435, 20.170452, 23.214834, 37.211243, 17.090082, 19.648254]

正如您所看到的,无限期地重复这些价值。

我错过了什么吗?

Tensorflow版本:0.12.0-rc1

Python版本:2.7.6

1 个答案:

答案 0 :(得分:0)

我们最近使用tf.contrib.learn修改了#34;构建输入函数&#34;教程是最新的Estimators中的predict()的最新默认行为,它返回一个生成器。修订后的教程文本在这里:

https://www.tensorflow.org/versions/master/tutorials/input_fn/

最新的代码在这里:

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/tutorials/input_fn/boston.py

以下是使用itertools.islice的相关部分,与上次尝试时一样:

y = regressor.predict(input_fn=lambda: input_fn(prediction_set))
# .predict() returns an iterator; convert to a list and print predictions                                                      
predictions = list(itertools.islice(y, 6))
print("Predictions: {}".format(str(predictions)))

当islice的第二个参数增加到30时,我没有看到预测值重复的行为。您是否可以尝试从上面的链接中提取GitHub中的最新代码,如果您仍然遇到该问题,请告诉我?