如何在TensorFlow中使用PriorityQueue?

时间:2018-08-03 08:09:11

标签: python tensorflow

在张量流中使用tf.PriorityQueue似乎有问题。该文档说,shapes参数对于初始化不是必需的。我无法指定张量的形状,因为它是动态的并且形状是在运行时确定的。

来自tf.PriorityQueue上的tensorflow文档:

__init__(capacity,types,shapes=None,names=None,shared_name=None,name='priority_queue')
  

Args:

     

容量:一个整数。此队列中可能存储的元素数的上限。

     

types:DType对象的列表。类型的长度必须等于   除第一个优先级元素外,每个队列元素中的张量。每个元素中的第一个张量是优先级,必须为int64类型。

     

shapes :(可选。)完整定义的TensorShape对象的列表,其长度与类型相同,或者为None。

     

names :(可选。)一列字符串,用于命名队列中与dtypes相同的长度或无的字符串。如果指定,出队方法将返回一个字典,其名称为键。

     

shared_name :(可选。)如果为非空,则该队列将以给定名称在多个会话中共享。

     

name:队列操作的可选名称。

但是,以下代码会产生TypeError:

def build_queue():
    with tf.name_scope("Queue"):
        q = tf.PriorityQueue(capacity=2,types=tf.uint8,name="iq",shared_name="queue")
    return q

File "C:\Users\devar\Documents\EngProj\SSPlayer\test\dist_cnn.py", line 212, in create_model
    infer_q = build_infer_queue()
  File "C:\Users\devar\Documents\EngProj\SSPlayer\test\dist_cnn.py", line 143, in build_queue
    shared_name="queue")
  File "C:\Users\devar\Envs\RL\lib\site-packages\tensorflow\python\ops\data_flow_ops.py", line 903, in __init__
    name=name)
  File "C:\Users\devar\Envs\RL\lib\site-packages\tensorflow\python\ops\gen_data_flow_ops.py", line 3409, in priority_queue_v2

TypeError: Expected list for 'shapes' argument to 'priority_queue_v2' Op, not None.

关于我在做什么错的任何想法吗?

1 个答案:

答案 0 :(得分:0)

根据priority_queue_v2gen_data_flow_ops.py中的注释,似乎必须存在。

  

Args:       形状:形状列表(每个tf.TensorShapeints列表)。         值中每个组件的形状。该属性的长度必须         为0或与component_types的长度相同。如果长度         此attr为0,队列元素的形状不受限制,并且         一次只能使一个元素出队。

它接受shapes=[()]shapes=[tf.TensorShape([])],但没有该错误,我会看到相同的错误。

import tensorflow as tf
import threading

list = tf.placeholder(tf.int64,name="x")


def build_queue():
    with tf.name_scope("Queue"):
        q = tf.PriorityQueue(2,tf.int64,shapes=[()],name="iq",shared_name="queue")
    return q

queue = build_queue()

enqueue_op = queue.enqueue_many([list,list])

dequeue_op = queue.dequeue()

data_batch = tf.train.batch([dequeue_op], batch_size=2, capacity=40)

init_op = tf.group(tf.global_variables_initializer(),
tf.local_variables_initializer(), tf.initialize_local_variables())

sess = tf.Session()

def put():
    sess.run(enqueue_op, feed_dict={list: [1,2,3,4,5]})

mythread = threading.Thread(target=put, args=())
mythread.start()

tf.train.start_queue_runners(sess)

try:
    while True:
        print (sess.run(data_batch))
except tf.errors.OutOfRangeError:
    print ("Queue empty")


sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run( build_queue() )