从tf.estimator.EstimatorSpec获取值

时间:2018-01-16 18:39:52

标签: python tensorflow

当PREDICT ModeKey传递给tf.estimator.Estimator时,返回的EstimatorSpec对象是可迭代的。

predictions = { 'class_ids': tf.argmax(input=logits, axis=1), 'logits': logits }
if mode == tf.estimator.ModeKeys.PREDICT:
  return tf.estimator.EstimatorSpec(mode, predictions=predictions)

然后,

prediction_results = classifier.predict(input_fn=prediction_dataset_input)
for x, each in enumerate(prediction_results):
    print(each)

为每个预测生成一行,如下所示;

....
{'class_ids': 1, 'logits': array([-32976400., -30171870.], dtype=float32)}
{'class_ids': 1, 'logits': array([-32958380., -30386898.], dtype=float32)}
{'class_ids': 1, 'logits': array([-32940332., -30601930.], dtype=float32)}
{'class_ids': 1, 'logits': array([-32922300., -30816956.], dtype=float32)}
....

我怎样才能让TRAIN ModeKey返回可迭代的东西?理想情况下,我想在训练中定义的步骤中打印使用的每个特征值和计算的logit。我要回来了;

if mode == tf.estimator.ModeKeys.TRAIN:
    return tf.estimator.EstimatorSpec(
            mode,
            loss=loss,
            predictions=predictions,
        train_op=train_objective)

然后,

prediction_results = classifier.train(input_fn=lambda: dataset_input_fn(
                file_dataset_feed,
                perform_shuffle=True,
                repeat_count=100))
for x, each in enumerate(prediction_results):
    print(each)

产生

TypeError: 'Estimator' object is not iterable

0 个答案:

没有答案