在TF中预加载数据

时间:2017-04-19 02:31:03

标签: python tensorflow queue

我希望在CNN中的培训TF期间预加载培训数据,我的简单实现如下。但是,我发现了一个奇怪的现象。这似乎是一个同步过程。无论PRE_FETCHTrue还是False,加载一个批处理数据的时间成本几乎相同。

class Demo(object):
    def __init__(self):
        self._name = 'demo'

    def load_batch(self):
        ...

    def prefetch(self, func):
        while True:
            data = func()
            self.queue.put(data)

    def train(self):
        input_data = tf.placeholder(tf.float32, shape=[B, H, W, C])
        optim_op = build_model(input_data)

        if PRE_FETCH:
            self.queue = Queue(30)
            self.process = Process(target=self.prefetch, args=(self.load_batch))
            self.process.start()
            def cleanup():
                self.process.terminate()
                self.process.join()
            import atexit
            atexit.register(cleanup)
        sess = tf.Session()
        i = 1
        while i < MAX_ITER_SIZE:
            if PRE_FETCH:
                start = time.time()
                tmp = self.queue.get()
                end = time.time()
                print 'load data time: ', (end - start)
            else:
                start = time.time()
                tmp = self.load_batch()
                end = time.time()
                print 'load data time: ', (end - start)
            sess.run(optim_op, feed_dict={input_data: tmp}

1 个答案:

答案 0 :(得分:0)

通过占位符将数据加载到图表中需要花费时间。如果你希望你的预加载有效,你应该调查替换你的python队列和线程mecanisme与tensorflow in-graph操作。有关如何在tensorflow网站上执行此操作的详细教程:https://www.tensorflow.org/programmers_guide/reading_data