Tensorflow - 获取队列中的样本数量?

时间:2016-10-22 11:03:16

标签: python tensorflow

对于性能监控,我想关注当前排队的示例。我正在平衡我用于填充队列的线程数量以及队列的最佳最大大小。 我如何获得这些信息?我使用的是tf.train.batch(),但我想这些信息可能位于FIFOQueue的某个地方? 我原以为这是一个局部变量,但我还没有找到它。

1 个答案:

答案 0 :(得分:5)

tldr:如果您的队列是由tf.batch创建的,则可以使用sess.run("batch/fifo_queue_Size:0")

获取大小

FIFOQueue对象提供了一个size()方法,该方法创建一个op,它在队列中提供多个元素。但是,如果您使用tf.batch,则会在方法内部创建FIFOQueue,并且此对象不会在外部公开。

特别是你在input.py

中看到了这一点
queue = _which_queue(dynamic_pad)(
    capacity=capacity, dtypes=types, shapes=shapes, shared_name=shared_name)
print("Enqueueing: ", enqueue_many, tensor_list, shapes)
_enqueue(queue, tensor_list, num_threads, enqueue_many)
summary.scalar("queue/%s/fraction_of_%d_full" % (queue.name, capacity),
               math_ops.cast(queue.size(), dtypes.float32) *
               (1. / capacity))

由于queue是本地的,因此无法保留其size()方法。但是,由于为了构造摘要而调用了size(),因此图中会显示相应的size op,您可以按名称调用它。您可以通过执行此类操作找到节点的名称

x = tf.constant(1)
q = tf.train.batch([x], 2)
tf.get_default_graph().as_graph_def()

你会看到

node {
  name: "batch/fifo_queue_Size"
  op: "QueueSize"
  input: "batch/fifo_queue"
  attr {
    key: "_class"
    value {
      list {

由此可以看出batch/fifo_queue_Size是op的名称,因此batch/fifo_queue_Size:0是第一个输出的名称,因此您可以通过执行以下操作来获得大小:

sess.run("batch/fifo_queue_Size:0")

如果您有多个batch操作,则会自动将这些名称重新导入batch_1/fifo_queue_Sizebatch_2/fifo_queue_Size

或者,您可以使用tf.batch(...name="mybatch")呼叫您的节点,然后张量的名称将为mybatch/fifo_queue_Size:0