在Tensorflow中获取占位符

时间:2018-07-10 00:44:08

标签: tensorflow

我从一个项目中看到了Tensorflow代码,如下所示:

sess.run(train_enqeue, feed_dict)

train_enqeue是从以下位置构建的:

train_queue = tf.FIFOQueue(train_params.async_encoding, [x.dtype for x in placeholders], name="train_queue")
train_enqeue = train_queue.enqueue(placeholders)

基本上是占位符的FIFO队列。我想知道在这种情况下传递占位符是什么意思?它是否从占位符返回值?

代码来自https://github.com/allenai/document-qa/blob/master/docqa/trainer.py的561行

1 个答案:

答案 0 :(得分:2)

它不是占位符的FIFO队列。它是张量的FIFO队列。占位符要求指定应将哪些值添加到队列中。

dequeue返回/获取被排队的元素:

import tensorflow as tf

input_a = tf.placeholder(tf.int32)
input_b = tf.placeholder(tf.float32)

queue = tf.FIFOQueue(20, [tf.int32, tf.float32], name="train_queue")
queue_add = queue.enqueue([input_a, input_b])
queue_fetch = queue.dequeue()

with tf.Session() as sess:
    sess.run(queue_add, {input_a: 42, input_b: 3.14159265358979})
    sess.run(queue_add, {input_a: 43, input_b: 4.14159265358979})
    sess.run(queue_add, {input_a: 44, input_b: 5.14159265358979})
    print(sess.run(queue_fetch))  # gives [42, 3.1415927]
    print(sess.run(queue_fetch))  # gives [43, 4.1415925]
    print(sess.run(queue_fetch))  # gives [44, 5.1415925]

为了DRY,您可以重写:

inputs = []
inputs.append(tf.placeholder(tf.int32))
inputs.append(tf.placeholder(tf.float32))

queue = tf.FIFOQueue(20, [x.dtype for x in inputs], name="train_queue")

代替

input_a = tf.placeholder(tf.int32)
input_b = tf.placeholder(tf.float32)

queue = tf.FIFOQueue(20, [tf.int32, tf.float32], name="train_queue")

假设有50个输入,并且您懒于编写所有这些类型,或者您只喜欢通用实现。

Permanent-link to line mentioned in question