如何在TensorFlow中混合基于队列和基于Feed的输入

时间:2016-06-02 02:28:56

标签: python-2.7 tensorflow

我最近迁移到一个full_connected样式模型,该模型从TFRecords文件生成的队列中读取输入。这已经证明效率更高,但我仍然想用placeholder / feed_dict以交互方式传递参数。

对于feed_dict和full_connected功能,是否有办法使用相同的计算图(假设您有一个在init方法中构建图的模型类)?你能获得占位符来接收出列值吗?

2 个答案:

答案 0 :(得分:8)

一种可能性是使用最近添加的(在TensorFlow 0.8中)tf.placeholder_with_default() op,它允许您指定默认值(通常是队列/阅读器的输出),并且还允许您提供值可能有不同的形状。

例如,假设您的队列生成32个元素的批次,其中每个元素具有784个特征,以提供32 x 784矩阵。

input_from_queue = ...  # e.g. `queue.dequeue_many(32)` or `tf.train.batch(..., 32)`
# input_from_queue.get_shape() ==> (32, 784)

input = tf.placeholder_with_default(input_from_queue, shape=(None, 784))
# input.get_shape() ==> (?, 784)

# ...
train_op = ...

sess.run(train_op)  # Takes examples from `queue`.
sess.run(train_op, feed_dict={input: ...})  # Takes examples from `feed_dict`.

这允许您根据需要提供可变大小的批次或使用输入阅读器。

答案 1 :(得分:3)

您可以简单地输入出列操作的输出。 TensorFlow实际上不会将任何项目出列,它只会使用您提供的值。例如:

q = tf.FIFOQueue(capacity=10, dtypes=[tf.float32], shapes=[()])
v = tf.placeholder(tf.float32)
enqueue = q.enqueue([v])
dequeue = q.dequeue()
output = dequeue + 10.0

with tf.Session() as sess:
    sess.run(enqueue, feed_dict={v: 1.0})
    sess.run(enqueue, feed_dict={v: 2.0})
    sess.run(enqueue, feed_dict={v: 3.0})
    print(sess.run(output)) # 11.0
    print(sess.run(output, feed_dict={dequeue: 5.0})) # 15.0
    print(sess.run(output)) # 12.0
    print(sess.run(output)) # 13.0