在张量流中使用tf.PriorityQueue
似乎有问题。该文档说,shapes参数对于初始化不是必需的。我无法指定张量的形状,因为它是动态的并且形状是在运行时确定的。
来自tf.PriorityQueue上的tensorflow文档:
__init__(capacity,types,shapes=None,names=None,shared_name=None,name='priority_queue')
Args:
容量:一个整数。此队列中可能存储的元素数的上限。
types:DType对象的列表。类型的长度必须等于 除第一个优先级元素外,每个队列元素中的张量。每个元素中的第一个张量是优先级,必须为int64类型。
shapes :(可选。)完整定义的TensorShape对象的列表,其长度与类型相同,或者为None。
names :(可选。)一列字符串,用于命名队列中与dtypes相同的长度或无的字符串。如果指定,出队方法将返回一个字典,其名称为键。
shared_name :(可选。)如果为非空,则该队列将以给定名称在多个会话中共享。
name:队列操作的可选名称。
但是,以下代码会产生TypeError:
def build_queue():
with tf.name_scope("Queue"):
q = tf.PriorityQueue(capacity=2,types=tf.uint8,name="iq",shared_name="queue")
return q
File "C:\Users\devar\Documents\EngProj\SSPlayer\test\dist_cnn.py", line 212, in create_model
infer_q = build_infer_queue()
File "C:\Users\devar\Documents\EngProj\SSPlayer\test\dist_cnn.py", line 143, in build_queue
shared_name="queue")
File "C:\Users\devar\Envs\RL\lib\site-packages\tensorflow\python\ops\data_flow_ops.py", line 903, in __init__
name=name)
File "C:\Users\devar\Envs\RL\lib\site-packages\tensorflow\python\ops\gen_data_flow_ops.py", line 3409, in priority_queue_v2
TypeError: Expected list for 'shapes' argument to 'priority_queue_v2' Op, not None.
关于我在做什么错的任何想法吗?
答案 0 :(得分:0)
根据priority_queue_v2
中gen_data_flow_ops.py
中的注释,似乎必须存在。
Args: 形状:形状列表(每个
tf.TensorShape
或ints
列表)。 值中每个组件的形状。该属性的长度必须 为0或与component_types的长度相同。如果长度 此attr为0,队列元素的形状不受限制,并且 一次只能使一个元素出队。
它接受shapes=[()]
或shapes=[tf.TensorShape([])]
,但没有该错误,我会看到相同的错误。
import tensorflow as tf
import threading
list = tf.placeholder(tf.int64,name="x")
def build_queue():
with tf.name_scope("Queue"):
q = tf.PriorityQueue(2,tf.int64,shapes=[()],name="iq",shared_name="queue")
return q
queue = build_queue()
enqueue_op = queue.enqueue_many([list,list])
dequeue_op = queue.dequeue()
data_batch = tf.train.batch([dequeue_op], batch_size=2, capacity=40)
init_op = tf.group(tf.global_variables_initializer(),
tf.local_variables_initializer(), tf.initialize_local_variables())
sess = tf.Session()
def put():
sess.run(enqueue_op, feed_dict={list: [1,2,3,4,5]})
mythread = threading.Thread(target=put, args=())
mythread.start()
tf.train.start_queue_runners(sess)
try:
while True:
print (sess.run(data_batch))
except tf.errors.OutOfRangeError:
print ("Queue empty")
sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run( build_queue() )