如何在Tensorflow中有效实现固定大小的缓冲区

时间:2019-01-15 02:51:07

标签: python multithreading tensorflow deep-learning

我正在执行在线RL算法,并且在训练期间会生成训练数据。因此,我需要一个固定大小的缓冲区来存储新生成的训练数据并放弃旧数据。

有两个主要功能:

  1. 将新生成的数据添加到缓冲区中,如果数据量超出容量,则放弃旧数据以满足容量。
  2. 从缓冲区随机抽样一批。

这在Tensorflow外部使用双端队列很容易使用python:

from collections import deque
import random 

Q = deque(range(100))
Q.extendleft(range(50)) # add operation

batch = random.sample(Q, 10) # batch operation

我已经尝试使用tf.FIFOQueue来做到这一点。

N_SAMPLES = 100
NUM_THREADS = 4
all_data = 10 * np.random.randn(N_SAMPLES, 4) + 1
all_target = np.random.randint(0, 2, size=N_SAMPLES)
queue = tf.FIFOQueue(capacity=50, dtypes=[tf.float32, tf.int32], shapes=[[4], []])
enqueue_op = queue.enqueue_many([all_data, all_target])
data_sample, label_sample = queue.dequeue()
qr = tf.train.QueueRunner(queue, [enqueue_op] * NUM_THREADS)
with tf.Session() as sess:
    # create a coordinator, launch the queue runner threads.
    coord = tf.train.Coordinator()
    enqueue_threads = qr.create_threads(sess, coord=coord, start=True)
    for step in range(100): # do to 100 iterations
        if coord.should_stop():
            break
        one_data, one_label = sess.run([data_sample, label_sample])
        print(one_data, one_label)
    coord.request_stop()
    coord.join(enqueue_threads)

但是似乎很难随机抽取一批数据进行训练。所以我想知道,Tensorflow内部是否还有其他功能支持该功能?例如,Tensorflow正式推荐的tf.data.Dataset?

而且,此数据结构应支持多线程添加和采样操作,以进行分布式训练。

0 个答案:

没有答案