Estimator预测无限循环

时间:2017-12-17 15:57:06

标签: python tensorflow google-cloud-ml

我不了解如何使用TensorFlow Estimator API进行单一预测 - 我的代码会导致无限循环,无法预测相同的输入。

根据documentation,当input_fn引发StopIteration异常时,预测应该停止:

  

input_fn:返回功能的输入函数,它是一个字典   字符串功能名称为Tensor或SparseTensor。如果它返回一个元组,   第一项被提取为要素。预测一直持续到   input_fn引发输入结束异常(OutOfRangeError或   StopIteration异常)。

这是我的代码中的相关部分:

classifier = tf.estimator.Estimator(model_fn=image_classifier, model_dir=output_dir,
                                    config=training_config, params=hparams)

def make_predict_input_fn(filename):
    queue = [ filename ]
    def _input_fn():
        if len(queue) == 0:
            raise StopIteration
        image = model.read_and_preprocess(queue.pop())
        return {'image': image}
    return _input_fn

predictions = classifier.predict(make_predict_input_fn('garden-rose-red-pink-56866.jpeg'))
for i, p in enumerate(predictions):
    print("Prediction %s: %s" % (i + 1, p["class"]))

我错过了什么?

2 个答案:

答案 0 :(得分:0)

That's because input_fn() needs to be a generator. Change your function to (yield instead of return):

def make_predict_input_fn(filename):
    queue = [ filename ]
    def _input_fn():
        if len(queue) == 0:
            raise StopIteration
        image = model.read_and_preprocess(queue.pop())
        yield {'image': image}
    return _input_fn

答案 1 :(得分:0)

一种解决方案是使用itertools.islice:

import itertools

predictions = itertools.islice(predictions, number_of_samples)

for i, p in enumerate(predictions):
    print("Prediction %s: %s" % (i + 1, p["class"]))

number_of_samples是一个整数,它是迭代器的停止点。