I'm trying to come up with a way to implement a FIFOQueue in TensorFlow. At each iteration, the goal is to feed a value into a placeholder and store it in a Variable named buffer. After each assignment I increment an index. The buffer has size [5], so the index should range from 0 to 4. Finally, once the buffer is full, I set buffer[0:4] to buffer[1:5] and then write the new value into buffer[4]. Here is my code:
import tensorflow as tf
import numpy as np
import random

dim = 30
lst = []
for i in range(dim):
    lst.append(random.randint(1, 10))

data = np.reshape(lst, [dim, 1])
print(lst)

# create a buffer:
buffer_input = tf.placeholder(tf.int32, shape=[1])
buffer = tf.Variable(tf.zeros([5], tf.int32))
index = tf.Variable(tf.constant(0))

def fillBufferBeforeFilled():
    update_op1 = tf.scatter_update(buffer, indices=[index], updates=buffer_input)
    index_assign_add = tf.assign_add(index, 1)
    return update_op1, index_assign_add

def fillBufferAfterFilled():
    tmp = tf.slice(buffer, begin=[0], size=[4])
    update_op2 = tf.scatter_update(buffer, indices=[0, 1, 2, 3], updates=tmp)
    update_op3 = tf.scatter_update(buffer, indices=[index], updates=buffer_input)
    return update_op2, update_op3

cond = tf.cond(tf.equal(index, 4), lambda: fillBufferBeforeFilled(), lambda: fillBufferAfterFilled())

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(dim):
        cond_ = sess.run(cond, feed_dict={buffer_input: data[i]})
        buf = sess.run(buffer, feed_dict={buffer_input: data[i]})
        print('buf: ', buf)
Problem: after each call, the index variable never gets incremented, while the first element of buffer keeps being assigned the value passed to the placeholder.

I would like to know why I am seeing this behavior and what the solution to this problem is. Any help is much appreciated!
Answer 0: (score: 0)
You have the order of the branches in your tf.cond mixed up; it should be:

cond = tf.cond(tf.equal(index, 4), lambda: fillBufferAfterFilled(), lambda: fillBufferBeforeFilled())

I can get your code to run with that change and it mostly works, but the updates are still not quite right; I suspect you need to add some tf.control_dependencies calls to force things to happen in the correct order.
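For reference, here is a minimal toy sketch (TF 1.x graph mode, hypothetical counter variable; not part of the original answer) of what tf.control_dependencies does: ops created inside the block are only run after the listed ops have executed.

import tensorflow as tf

counter = tf.Variable(0)
increment = tf.assign_add(counter, 1)
with tf.control_dependencies([increment]):
    # this read op is created inside the block, so it can only run after `increment`
    after_increment = counter.read_value()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(after_increment))  # prints 1: the increment is forced to happen first

The same pattern can be used to make the buffer shift finish before the new value is written.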
Answer 1: (score: 0)
Here is the solution:
import tensorflow as tf
import numpy as np
import random

dim = 30
lst = []
for i in range(dim):
    lst.append(random.randint(1, 10))

data = np.reshape(lst, [dim, 1])
print(lst)

# create a buffer:
buffer_input = tf.placeholder(tf.int32, shape=[1])
buffer = tf.Variable(tf.zeros([5], tf.int32))
index = tf.Variable(-1, tf.int32)  # start at -1 so the first increment lands on 0

def fillBufferBeforeFilled():
    # increment the index first, then write the new value at that position
    index_assign_add = tf.assign_add(index, 1)
    with tf.control_dependencies([index_assign_add]):
        update_op1 = tf.scatter_update(buffer, indices=[index], updates=buffer_input)
    return update_op1, index_assign_add

def fillBufferAfterFilled():
    # shift buffer[1:5] into buffer[0:4], then write the new value at the end
    tmp = tf.slice(buffer, begin=[1], size=[4])
    update_op2 = tf.scatter_update(buffer, indices=[0, 1, 2, 3], updates=tmp)
    with tf.control_dependencies([update_op2]):
        update_op3 = tf.scatter_update(buffer, indices=[index], updates=buffer_input)
    return update_op2, update_op3

cond = tf.cond(tf.equal(index, 4), lambda: fillBufferAfterFilled(), lambda: fillBufferBeforeFilled())

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(dim):
        cond_ = sess.run(cond, feed_dict={buffer_input: data[i]})
        buf = sess.run(buffer, feed_dict={buffer_input: data[i]})
        print('buf: ', buf)
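For comparison, here is a plain-Python sketch (illustration only, not part of the original answer) of the FIFO behaviour the graph above is meant to reproduce: fill the five slots first, then shift everything left by one and append the newest value at the end.

buffer = [0] * 5
index = 0
for value in [3, 7, 1, 9, 4, 8, 2]:  # example input stream
    if index < 5:
        buffer[index] = value        # fill phase: write at the running index
        index += 1
    else:
        buffer[:4] = buffer[1:]      # full phase: shift left by one slot
        buffer[4] = value            # append the newest value
    print(buffer)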