所以我有一个使用tensorflow的新estimator API编写的预测网络
我这样称呼预测:
prediction = classifier_mycolumn.predict(cust_train_fn)
其中classifier_mycolumn
定义为:
classifier_mycolumn = tf.estimator.Estimator(
model_fn = model_function,
params = {'feature_columns': my_feature_columns,
'n_classes': 1,},
model_dir=current_model)
只是标准的tf.estimator.Estimator东西。
cust_train_fn
实际上不是一个函数,它是网络应该接受并对其进行推理的单个数组。
网络应该返回0或1(最后是softmax的简单分类)-但是每次我尝试打印prediction
时-我只会得到<generator object Estimator.predict at 0x7fe0e8630c50>
有什么办法可以得到预测的实际结果?例如,网络返回0还是1?非常感谢您阅读。