在张量流中手动实现FIFOQueue所面临的问题

时间:2017-12-09 15:16:23

标签: python tensorflow lambda queue

我正在尝试提出一种可以在tensorflow中实现FIFOQueue的方法。因此,在每次迭代时,目的是为placeholder分配一定数量,然后将其存储在Variable名为: 缓冲区 中。每次分配后,我都在递增一个索引。缓冲区大小为[5],因此索引的范围应为0到4.最后,在缓冲区已满后,我将buffer[0:4]设置为buffer[1:5],然后将新值添加到{ {1}}。所以这是我的

buffer[4]

问题:每次调用后import tensorflow as tf import numpy as np import random dim = 30 lst = [] for i in range(dim): lst.append(random.randint(1, 10)) data = np.reshape(lst, [dim, 1]) print(lst) # create a buffer: buffer_input = tf.placeholder(tf.int32, shape=[1]) buffer = tf.Variable(tf.zeros([5], tf.int32)) index = tf.Variable(tf.constant(0)) def fillBufferBeforeFilled(): update_op1 = tf.scatter_update(buffer, indices=[index], updates=buffer_input) index_assign_add = tf.assign_add(index, 1) return update_op1, index_assign_add def fillBufferAfterFilled(): tmp = tf.slice(buffer, begin=[0], size=[4]) update_op2 = tf.scatter_update(buffer, indices=[0, 1, 2, 3], updates=tmp) update_op3 = tf.scatter_update(buffer, indices=[index], updates=buffer_input) return update_op2, update_op3 cond = tf.cond(tf.equal(index, 4), lambda: fillBufferBeforeFilled(), lambda: fillBufferAfterFilled()) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) for i in range(dim): cond_ = sess.run(cond, feed_dict={buffer_input: data[i]}) buf = sess.run(buffer, feed_dict={buffer_input: data[i]}) print('buf: ', buf) 变量都不会递增,而index的第一个元素将被赋值给传递给占位符的值。

我想知道为什么我会遇到这种行为,这个问题的解决方案是什么。

非常感谢任何帮助!!

2 个答案:

答案 0 :(得分:0)

你混淆了tf.cond中条件的顺序;它应该是

cond = tf.cond(tf.equal(index, 4), lambda: fillBufferAfterFilled(), lambda: fillBufferBeforeFilled())

我可以让你的代码运行,它主要起作用,但更新不太正确;我怀疑你需要添加一些tf.control_dependencies次调用以强制事情以正确的顺序发生。

答案 1 :(得分:0)

以下是解决方案:

import tensorflow as tf
import numpy as np
import random

dim = 30

lst = []
for i in range(dim):
    lst.append(random.randint(1, 10))

data = np.reshape(lst, [dim, 1])
print(lst)

# create a buffer:
buffer_input = tf.placeholder(tf.int32, shape=[1])

buffer = tf.Variable(tf.zeros([5], tf.int32))

index = tf.Variable(-1, tf.int32)

def fillBufferBeforeFilled():
    index_assign_add = tf.assign_add(index, 1)
    with tf.control_dependencies([index_assign_add]):
        update_op1 = tf.scatter_update(buffer, indices=[index], updates=buffer_input)

    return update_op1, index_assign_add

def fillBufferAfterFilled():
    tmp = tf.slice(buffer, begin=[1], size=[4])
    update_op2 = tf.scatter_update(buffer, indices=[0, 1, 2, 3], updates=tmp)
    with tf.control_dependencies([update_op2]):
        update_op3 = tf.scatter_update(buffer, indices=[index], updates=buffer_input)

    return update_op2, update_op3

cond = tf.cond(tf.equal(index, 4), lambda: fillBufferAfterFilled(), lambda: fillBufferBeforeFilled())

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(dim):
        cond_ = sess.run(cond, feed_dict={buffer_input: data[i]})
        buf = sess.run(buffer, feed_dict={buffer_input: data[i]})
        print('buf: ', buf)