我目前正试图摆脱使用Feed并开始使用队列以支持更大的数据集。使用队列对张量流中的优化器工作正常,因为它们仅为每个出列操作计算一次渐变。但是,我已经与执行行搜索的其他优化器实现了接口,我不仅需要评估渐变,还要评估同一批次的多个点的损失。不幸的是,对于正常的排队系统,每次损失评估都会执行一次出列而不是多次计算同一批次。
有没有办法将出列操作与梯度/损耗计算分离,以便我可以执行一次出队,然后在当前批次上多次执行梯度/损耗计算?
编辑:请注意,我的输入张量的大小在批次之间是可变的。我们使用分子数据,每个分子都有不同数量的原子。这与图像数据完全不同,图像数据通常按比例缩放以具有相同的尺寸。
答案 0 :(得分:4)
通过创建变量存储出列值来解耦,然后依赖于此变量而不是出列操作。推进队列发生在assign
解决方案#1 :固定大小的数据,使用变量
(image_batch_live,) = tf.train.batch([image],batch_size=5,num_threads=1,capacity=614)
image_batch = tf.Variable(
tf.zeros((batch_size, image_size, image_size, color_channels)),
trainable=False,
name="input_values_cached")
advance_batch = tf.assign(image_batch, image_batch_live)
现在image_batch
提供队列的最新值而不推进它,advance_batch
推进队列。
解决方案#2 :可变大小数据,使用持久性张量
我们通过引入dequeue_op
和dequeue_op2
来解决工作流程。所有计算都取决于dequeue_op2
,其中dequeue_op
的保存值。使用get_session_tensor/get_session_handle
可确保实际数据保留在TensorFlow运行时中,并且通过feed_dict
传递的值是一个短字符串标识符。由于dummy_handle
,API有点尴尬,我提出了这个问题here
import tensorflow as tf
def create_session():
sess = tf.InteractiveSession(config=tf.ConfigProto(operation_timeout_in_ms=3000))
return sess
tf.reset_default_graph()
sess = create_session()
dt = tf.int32
dummy_handle = sess.run(tf.get_session_handle(tf.constant(1)))
q = tf.FIFOQueue(capacity=20, dtypes=[dt])
enqueue_placeholder = tf.placeholder(dt, shape=[None])
enqueue_op = q.enqueue(enqueue_placeholder)
dequeue_op = q.dequeue()
size_op = q.size()
dequeue_handle_op = tf.get_session_handle(dequeue_op)
dequeue_placeholder, dequeue_op2 = tf.get_session_tensor(dummy_handle, dt)
compute_op1 = tf.reduce_sum(dequeue_op2)
compute_op2 = tf.reduce_sum(dequeue_op2)+1
# fill queue with variable size data
for i in range(10):
sess.run(enqueue_op, feed_dict={enqueue_placeholder:[1]*(i+1)})
sess.run(q.close())
try:
while(True):
dequeue_handle = sess.run(dequeue_handle_op) # advance the queue
val1 = sess.run(compute_op1, feed_dict={dequeue_placeholder: dequeue_handle.handle})
val2 = sess.run(compute_op2, feed_dict={dequeue_placeholder: dequeue_handle.handle})
size = sess.run(size_op)
print("val1 %d, val2 %d, queue size %d" % (val1, val2, size))
except tf.errors.OutOfRangeError:
print("Done")
运行时应该会看到类似下面的内容
val1 1, val2 2, queue size 9
val1 2, val2 3, queue size 8
val1 3, val2 4, queue size 7
val1 4, val2 5, queue size 6
val1 5, val2 6, queue size 5
val1 6, val2 7, queue size 4
val1 7, val2 8, queue size 3
val1 8, val2 9, queue size 2
val1 9, val2 10, queue size 1
val1 10, val2 11, queue size 0
Done