无法获得tensorflow DNNClassifier的预测

时间:2016-11-20 15:18:43

标签: python tensorflow

我正在使用MNIST教程中的代码:

feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]
classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,
                                            hidden_units=[10, 20, 10],
                                            n_classes=2,
                                            model_dir="/tmp/iris_model")

classifier.fit(x=np.array(train, dtype = 'float32'),
               y=np.array(y_tr, dtype = 'int64'),
               steps=2000)

accuracy_score = classifier.evaluate(x=np.array(test, dtype = 'float32'),
                                     y=y_test)["auc"]
print('AUC: {0:f}'.format(accuracy_score))

from tensorflow.contrib.learn import SKCompat
ds_test_ar = np.array(ds_test, dtype = 'float32')

ds_predict_tf = classifier.predict(input_fn = _my_predict_data)
print('Predictions: {}'.format(str(ds_predict_tf)))

但最后我得到了以下结果而不是预测:

Predictions: <generator object DNNClassifier.predict.<locals>.<genexpr> at 0x000002CE41101CA8>

我做错了什么?

5 个答案:

答案 0 :(得分:13)

您收到并保存到ds_predict_tf的内容是生成器表达式。 要打印它,你可以这样做:

for i in ds_predict_tf:
    print i

print(list(ds_predict_tf))

您可以阅读有关genexpr here的更多信息。

答案 1 :(得分:9)

  

DNNC分类器预测功能默认为 as_iterable = True 。因此,它返回一个生成器。要获取预测值而不是生成器,请在classifier.predict方法中传递 as_iterable = False

例如,

  

classifier.predict(input_fn = _my_predict_data,as_iterable=False)

了解有关分类器方法和参数的更多信息。以下是预测方法的文档的一部分。

来自DNNClassifier文档:

预测

  

参数数量:

  • x:features。
  • input_fn:输入功能。如果设置,则x必须为None。
  • batch_size:覆盖默认批量大小。
  • 输出:str列表,要预测的输出名称。如果为None,则返回类。
  • as_iterable:如果为True,则返回一个迭代,该迭代继续为每个示例产生预测,直到输入耗尽为止。注意:如果您希望迭代终止,则输入必须终止(例如,如果您使用的是read_batch_features,请确保传递num_epochs = 1。)

  

返回:

  • 具有形状[batch_size]的预测类的Numpy数组(如果as_iterable为True,则或预测类的可迭代)。每个预测类由其类索引(即从0到n_classes-1的整数)表示。如果设置了输出,则返回预测字典。

答案 2 :(得分:1)

<强>解决方案: -

pred = classifier.fit(x=training_set.data, y=training_set.target, steps=2000).predict(test_set.data)

print ("Predictions:")

print(list(pred))

那是......

答案 3 :(得分:1)

尽可能接近教程使用:

print('Predictions: {}' .format(list(ds_predict_tf)))

答案 4 :(得分:0)

很抱歉,答案非常简单,您需要使用predictor作为generator对象:

g1 = ds_predict_tf

[g1.__next__() for i in range(100)]