在python multiprocessing Queue中推进tensorflow数据集迭代器

时间:2019-03-05 06:25:48

标签: python tensorflow tensorflow-datasets

在此示例中,有什么方法可以移动迭代器?

import tensorflow as tf
import numpy as np
from multiprocessing import Process, Queue

def store(batch, queue):
    while True:
        queue.put(batch)


if __name__=='__main__':
    pqueue = Queue()
    a1 = np.arange(1000)

    m = tf.data.Dataset.from_tensor_slices(a1).repeat().batch(1)
    iter_m = m.make_one_shot_iterator()
    m_init_ops = iter_m.make_initializer(m)
    next_m = iter_m.get_next()

    with tf.Session() as sess:
        batch = sess.run(next_m)
        pp_process = Process(target=store,args=(batch, pqueue,))
        pp_process.daemon = True
        pp_process.start()

        for i in range(10):
            print(pqueue.get())

我的想法是将处理后的数据存储在可由tensorflow访问以进行训练的队列中,不幸的是我无法推进迭代器。任何建议将不胜感激。

当前输出为

[0]
[0]
[0]
[0]
[0]
[0]
[0]
[0]
[0]
[0]

1 个答案:

答案 0 :(得分:1)

Tensorflow多线程

迭代器不前进,因为从技术上讲,您只执行一次get_next操作:sess.run(next_m)。如果仅使用张量流多线程,则只需将其移至store函数中即可获得所需的结果:

def store(sess, next_m, queue):
    while True:
        queue.put(sess.run(next_m))

# batch = sess.run(next_m) <- Remove
pp_process = Thread(target=store,args=(sess, next_m, pqueue,)) # <- Thread with correct args passed

Tensorflow多处理

但是,对于多处理,由于会话对象不可序列化,因此还应确保在创建会话后再也不要实例化(分叉)新进程。
就您而言,您可以简单地在store函数中创建一个新会话,并在分叉之后启动主会话:

from multiprocessing import Process, Queue

import numpy as np
import tensorflow as tf


def store(next_m, queue):
    with tf.Session() as sess:
        while True:
            queue.put(sess.run(next_m))


if __name__ == '__main__':
    ...
    pp_process = Process(target=store, args=(next_m, pqueue,))
    pp_process.daemon = True
    pp_process.start() # <- Fork before starting this session!

    with tf.Session() as sess:
        for i in range(10):
            print(pqueue.get())