不使用tf.session()获取Tensorflow model.predict的值

时间:2018-08-23 20:46:15

标签: python tensorflow prediction tensorflow-estimator

所以我有一个使用tensorflow的新estimator API编写的预测网络

我这样称呼预测:

prediction = classifier_mycolumn.predict(cust_train_fn)

其中classifier_mycolumn定义为:

classifier_mycolumn = tf.estimator.Estimator(
        model_fn = model_function,
        params = {'feature_columns': my_feature_columns,
        'n_classes': 1,},
        model_dir=current_model)

只是标准的tf.estimator.Estimator东西。

cust_train_fn实际上不是一个函数,它是网络应该接受并对其进行推理的单个数组。

网络应该返回0或1(最后是softmax的简单分类)-但是每次我尝试打印prediction时-我只会得到<generator object Estimator.predict at 0x7fe0e8630c50>

有什么办法可以得到预测的实际结果?例如,网络返回0还是1?非常感谢您阅读。

0 个答案:

没有答案