随机批量生成器Tensorflow

时间:2017-07-24 13:08:19

标签: tensorflow queue

我想通过在Tensorflow中排队我的数据(x_data和y_target)来开发随机批量生成。这是我的代码:

x_data=np.asarray(x_data) # shape (2404,60,41,2)
y_target=np.asarray(y_data) # shape (2404,7)

queue_x_data = tf.placeholder(tf.float64, shape=[20,60,41,2])
queue_y_target = tf.placeholder(tf.int64, shape=[20,7])
q = tf.FIFOQueue(capacity=50, dtypes=[tf.float64,tf.int64], shapes=[(60,41,2), (7)])
init = q.enqueue_many([queue_x_data, queue_y_target])
dequeue_op = q.dequeue()
data_batch, target_batch = tf.train.shuffle_batch(dequeue_op,batch_size=32,num_threads=4,capacity=50000, min_after_dequeue=10000)

with tf.Session() as sess:
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  curr_data_batch, curr_target_batch = sess.run([data_batch, target_batch])

这里我会尝试打印生成的批次。

  print(curr_data_batch)
  coord.request_stop()
  coord.join(threads)

但是我无法打印任何生成的批次。我究竟做错了什么?感谢

2 个答案:

答案 0 :(得分:0)

您尝试从占位符填充队列,但占位符必须通过转移free_dict参数来填充数据,然后您调用sess.run()

答案 1 :(得分:0)

有两个原因。

  • 首先,您从未致电init操作:sess.run(init, {queue_x_data: x_data, queue_y_target: y_target}))
  • 其次,min_after_dequeue=10000与队列中的元素数量相比较高,尝试将其设置为较低级别。