在tensorflow中,如何从生成器中读取我的预测?

时间:2017-08-28 06:12:18

标签: python tensorflow machine-learning prediction tensorflow-estimator

我一直在调整以下tensorflow教程中最卷积的网络: https://www.tensorflow.org/tutorials/layers

除输入的形状外,我使用相同的代码。当我训练和评估时,我得到了很好的结果。但我希望看到一个预测,以便我可以知道错误分类的内容。但是在运行时

y=SN_classifier.predict(input_fn=my_data_to_predict)

其中my_data_to_predict是一个正确形状的numpy数组,我得到以下输出:

<generator object Estimator.predict at 0x7fb1ecefeaf0>

我在论坛上看到我应该能够读到它:     因为我在y:       打印(ⅰ)

但它提出了     &#39; numpy.ndarray&#39;对象不可调用

如果我尝试,也会发生同样的事情:

print('Predictions: {}'.format(list(y))

我在其他论坛上阅读..

您是否知道为什么它不会输出我的预测?

以下是我定义预测的代码部分:

predictions = {
      # Generate predictions (for PREDICT and EVAL mode)
      "classes": tf.argmax(input=logits, axis=1),
      # Add `softmax_tensor` to the graph. It is used for PREDICT and by the
      # `logging_hook`.
      "probabilities": tf.nn.softmax(logits, name="softmax_tensor")
    }
    if mode == tf.estimator.ModeKeys.PREDICT:
        return(tf.estimator.EstimatorSpec(mode=mode, predictions=predictions))

我称之为:

y=SN_classifier.predict(input_fn=my_data_to_predict)

非常感谢你的帮助,我会接受任何建议,想法:)

3 个答案:

答案 0 :(得分:3)

let arr = [{ "LOCATION_ID": 1001, "LOCATIONS": [{ "LOCATION_ID": 2001, "LOCATIONS": [{ "LOCATION_ID": 2002, "LOCATIONS": [{ "LOCATION_ID": 3002 }] }] }] }, { "LOCATION_ID": 5001 } ]; function getMoreData(locationId){ return "some data for " + locationId; } function recursivelyTraverseArr(locations){ // get an array of locations for(let location of locations){ if(location.LOCATION_ID){ // if location_id exists add new property location.data = getMoreData(location.LOCATION_ID); } if(location.LOCATIONS){ // if locations exist recursively call inside recursivelyTraverseArr(location.LOCATIONS); } } } recursivelyTraverseArr(arr); console.log(arr);应该是一个生成张量的函数。将它包装在numpy_input_fn中应该是您所需要的一切。

input_fn

答案 1 :(得分:1)

预测函数返回一个生成器,因此您可以一次获得包含所有预测的整个字典。

predictor = SN_classifier.predict(input_fn=my_data_to_predict)

# this is how to get your results:

predictions_dict = next(predictor)

答案 2 :(得分:0)

有一种方法可以将 classifier.predict 函数返回的生成器打开,只需将生成器包装在“ 列表”中即可:

predictor = SN_classifier.predict(input_fn=my_data_to_predict);
results   = list(predictor);
tf.logging.info(results);