我希望在CNN
中的培训TF
期间预加载培训数据,我的简单实现如下。但是,我发现了一个奇怪的现象。这似乎是一个同步过程。无论PRE_FETCH
是True
还是False
,加载一个批处理数据的时间成本几乎相同。
class Demo(object):
def __init__(self):
self._name = 'demo'
def load_batch(self):
...
def prefetch(self, func):
while True:
data = func()
self.queue.put(data)
def train(self):
input_data = tf.placeholder(tf.float32, shape=[B, H, W, C])
optim_op = build_model(input_data)
if PRE_FETCH:
self.queue = Queue(30)
self.process = Process(target=self.prefetch, args=(self.load_batch))
self.process.start()
def cleanup():
self.process.terminate()
self.process.join()
import atexit
atexit.register(cleanup)
sess = tf.Session()
i = 1
while i < MAX_ITER_SIZE:
if PRE_FETCH:
start = time.time()
tmp = self.queue.get()
end = time.time()
print 'load data time: ', (end - start)
else:
start = time.time()
tmp = self.load_batch()
end = time.time()
print 'load data time: ', (end - start)
sess.run(optim_op, feed_dict={input_data: tmp}
答案 0 :(得分:0)
通过占位符将数据加载到图表中需要花费时间。如果你希望你的预加载有效,你应该调查替换你的python队列和线程mecanisme与tensorflow in-graph操作。有关如何在tensorflow网站上执行此操作的详细教程:https://www.tensorflow.org/programmers_guide/reading_data