在此示例中,有什么方法可以移动迭代器?
import tensorflow as tf
import numpy as np
from multiprocessing import Process, Queue
def store(batch, queue):
while True:
queue.put(batch)
if __name__=='__main__':
pqueue = Queue()
a1 = np.arange(1000)
m = tf.data.Dataset.from_tensor_slices(a1).repeat().batch(1)
iter_m = m.make_one_shot_iterator()
m_init_ops = iter_m.make_initializer(m)
next_m = iter_m.get_next()
with tf.Session() as sess:
batch = sess.run(next_m)
pp_process = Process(target=store,args=(batch, pqueue,))
pp_process.daemon = True
pp_process.start()
for i in range(10):
print(pqueue.get())
我的想法是将处理后的数据存储在可由tensorflow访问以进行训练的队列中,不幸的是我无法推进迭代器。任何建议将不胜感激。
当前输出为
[0]
[0]
[0]
[0]
[0]
[0]
[0]
[0]
[0]
[0]
答案 0 :(得分:1)
迭代器不前进,因为从技术上讲,您只执行一次get_next操作:sess.run(next_m)
。如果仅使用张量流多线程,则只需将其移至store
函数中即可获得所需的结果:
def store(sess, next_m, queue):
while True:
queue.put(sess.run(next_m))
# batch = sess.run(next_m) <- Remove
pp_process = Thread(target=store,args=(sess, next_m, pqueue,)) # <- Thread with correct args passed
但是,对于多处理,由于会话对象不可序列化,因此还应确保在创建会话后再也不要实例化(分叉)新进程。
就您而言,您可以简单地在store函数中创建一个新会话,并在分叉之后启动主会话:
from multiprocessing import Process, Queue
import numpy as np
import tensorflow as tf
def store(next_m, queue):
with tf.Session() as sess:
while True:
queue.put(sess.run(next_m))
if __name__ == '__main__':
...
pp_process = Process(target=store, args=(next_m, pqueue,))
pp_process.daemon = True
pp_process.start() # <- Fork before starting this session!
with tf.Session() as sess:
for i in range(10):
print(pqueue.get())