打印estimator.predict张量的问题

时间:2018-02-05 07:30:35

标签: python tensorflow tensorflow-estimator

我目前正在使用tf 1.4,我需要帮助查看tf.contrib.factorization.KMeansClustering估算器的预测。我目前的代码段如下:

km = KMeansClustering(num_clusters=8,initial_clusters=KMeansClustering.KMEANS_PLUS_PLUS_INIT,model_dir=MODEL,relative_tolerance=0.01)

result = km.train(input_fn=lambda: gen_input(body))

input_fn = tf.estimator.inputs.pandas_input_fn(x={'x':tst}, shuffle=False)

y = result.predict(input_fn)

body和tst是pandas数据帧。 print(y)给出:

<generator object Estimator.predict at 0x11ebecba0>

尝试我搜索的内容,例如调用print(list(y))print(next(y))或迭代y,如:

for i in y:
    ...

for i in y.items():
    ...

for i in enumerate(y):
    ...

等,给出错误TypeError: data must be either a numpy array or pandas DataFrame if pandas is installed; got dict。我找不到任何其他方法来尝试在线打印。感谢

1 个答案:

答案 0 :(得分:-1)

您的代码太少,无法确认错误/遗漏。此外,预计至少会有完整的堆栈跟踪。在您添加更多信息后,此答案可能会很好。

您是否期望拨打pandas_input_fn的电话返回的内容超出您的预期?它返回一个带有签名()->(dict of features, target)的函数。有关详细信息,请参阅docs

此外,您似乎没有运行TensorFlow session。在你这样做之前,所有的张量,计算(在你的情况下都是预测)等只是图形的一部分,它们只有在运行TF会话后才会有值。

有关详细信息,请参阅这些docs