我如何检查tf.estimator.inputs.numpy_input_fn的内容?

时间:2018-06-11 02:08:45

标签: python tensorflow

我想反复训练一组数据的张量流图,我认为tf.estimator.inputs.numpy_input_fn可能是我正在寻找的。我发现批量大小,重复,时期和迭代器之间的区别令人难以置信地混淆,因此我开始尝试检查数据集的内容以试图弄清楚实际发生了什么。但是,每当我尝试这样做时,我的程序就会挂起。

这是我想出的最复杂的测试案例:

{
  "took": 0,
  "timed_out": false,
  "_shards": {
    "total": 5,
    "successful": 5,
    "skipped": 0,
    "failed": 0
  },
  "hits": {
    "total": 0,
    "max_score": null,
    "hits": []
  }
}

该程序输出

import tensorflow as tf
import numpy

class TestMock(tf.test.TestCase):
    def test(self):
        inputs = numpy.array(range(10))
        targets = numpy.array(range(10,20))

        input_fn = tf.estimator.inputs.numpy_input_fn(
            x=inputs,
            y=targets,
            batch_size=1,
            num_epochs=2,
            shuffle=False)

        print input_fn()
        with self.test_session() as sess:
            # sess.run(input_fn()[0]) # it'll hang if I run this
            pass

if __name__ == '__main__':
    tf.test.main()

这似乎是合理的,但是一旦我尝试运行(<tf.Tensor 'fifo_queue_DequeueUpTo:1' shape=(?,) dtype=int64>, <tf.Tensor 'fifo_queue_DequeueUpTo:2' shape=(?,) dtype=int64>) 行,我的程序会冻结,我必须终止这个过程。我在这里做错了什么?

我想要做的是确保我加入我的流程的数据实际上是我认为的,但我认为如果没有检查数据的能力我就能做到这一点

1 个答案:

答案 0 :(得分:2)

从上面的打印语句中我们可以推断input_fn返回queue ops,我们需要使用start_queue_runners and Coordinator运行它们:

 features_op, labels_op = input_fn()
 with tf.Session() as sess:
     # initialise and start the queues.
     sess.run(tf.local_variables_initializer())

     coordinator = tf.train.Coordinator()
     _ = tf.train.start_queue_runners(coord=coordinator)

    print(sess.run([features_op, labels_op]))

    #[array([0]), array([10])]