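I'm new to TensorFlow and I'm learning how to use queue runners. I want to read the binary files in a directory and turn each file into an array. I use two reader threads and produce batches of 4 arrays. The code is as follows: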
import glob
import tensorflow as tf

def readfile(filenames_queue):
    filename = filenames_queue.dequeue()
    value_strings = tf.read_file(filename)
    array = tf.decode_raw(value_strings, tf.uint8)
    return [array]

def input_pipeline(filenames, batch_size, num_threads=2):
    filenames_queue = tf.train.string_input_producer(filenames)
    thread_lists = [readfile(filenames_queue) for _ in range(num_threads)]
    min_after_dequeue = 1000
    capacity = min_after_dequeue + 3 * batch_size
    arrays = tf.train.shuffle_batch_join(thread_lists, batch_size, capacity, min_after_dequeue)
    return arrays

if __name__ == "__main__":
    filenames = glob.glob('dir/*')
    arrays_batch = input_pipeline(filenames, 4)
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess, coord)
        for i in range(100):
            print sess.run(arrays_batch)
        coord.request_stop()
        coord.join(threads)
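For readers who want to see the same producer/consumer shape without TensorFlow, here is a rough Python 3 standard-library analogue of the pipeline above. It is a sketch under simplifying assumptions, not how TensorFlow implements its queues: read_files and this input_pipeline are hypothetical stand-ins, every file is read exactly once (string_input_producer instead cycles through the filenames, shuffled by default), and shuffling happens after all reads finish rather than from a buffer of min_after_dequeue elements.

```python
import queue
import random
import threading


def read_files(filenames_queue, examples):
    """Reader thread (hypothetical): roughly what readfile() does, one file -> one uint8 array."""
    while True:
        try:
            filename = filenames_queue.get_nowait()
        except queue.Empty:
            return  # no filenames left, so this reader stops
        with open(filename, "rb") as f:
            # list(bytes) yields ints in 0..255, analogous to tf.decode_raw(..., tf.uint8)
            examples.put(list(f.read()))


def input_pipeline(filenames, batch_size, num_threads=2, seed=None):
    """Read every file once with num_threads readers, shuffle, then yield full batches."""
    filenames_queue = queue.Queue()
    for name in filenames:
        filenames_queue.put(name)
    examples = queue.Queue()
    readers = [threading.Thread(target=read_files, args=(filenames_queue, examples))
               for _ in range(num_threads)]
    for t in readers:
        t.start()
    for t in readers:
        t.join()
    pool = []
    while not examples.empty():
        pool.append(examples.get())
    random.Random(seed).shuffle(pool)
    # Only full batches are produced, similar to allow_smaller_final_batch=False
    for start in range(0, len(pool) - batch_size + 1, batch_size):
        yield pool[start:start + batch_size]
```

In this toy, list(input_pipeline(glob.glob('dir/*'), 4)) yields no batches at all when fewer than 4 files are found, including when the glob matches nothing.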
I have fixed the errors that Victor and Sorin pointed out, but now I get a new error:
  File "input_queue.py", line 36, in <module>
    print sess.run(im_arrays_batch)
  File "/usr/local/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 889, in run
    run_metadata_ptr)
  File "/usr/local/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1120, in _run
    feed_dict_tensor, options, run_metadata)
  File "/usr/local/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1317, in _do_run
    options, run_metadata)
  File "/usr/local/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1336, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.OutOfRangeError: RandomShuffleQueue '_1_shuffle_batch_join/random_shuffle_queue' is closed and has insufficient elements (requested 2, current size 0)
     [[Node: shuffle_batch_join = QueueDequeueManyV2[component_types=[DT_UINT8], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/device:CPU:0"](shuffle_batch_join/random_shuffle_queue, shuffle_batch_join/n)]]
Caused by op u'shuffle_batch_join', defined at:
  File "input_queue.py", line 30, in <module>
    im_arrays_batch = input_pipeline(filenames,2)
  File "input_queue.py", line 23, in input_pipeline
    arrays_batch = tf.train.shuffle_batch_join(thread_lists,batch_size,capacity,min_after_dequeue)
  File "/usr/local/anaconda2/lib/python2.7/site-packages/tensorflow/python/training/input.py", line 1367, in shuffle_batch_join
    name=name)
  File "/usr/local/anaconda2/lib/python2.7/site-packages/tensorflow/python/training/input.py", line 833, in _shuffle_batch_join
    dequeued = queue.dequeue_many(batch_size, name=name)
  File "/usr/local/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/data_flow_ops.py", line 464, in dequeue_many
    self._queue_ref, n=n, component_types=self._dtypes, name=name)
  File "/usr/local/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/gen_data_flow_ops.py", line 2418, in _queue_dequeue_many_v2
    component_types=component_types, timeout_ms=timeout_ms, name=name)
  File "/usr/local/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/usr/local/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2956, in create_op
    op_def=op_def)
  File "/usr/local/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1470, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

OutOfRangeError (see above for traceback): RandomShuffleQueue '_1_shuffle_batch_join/random_shuffle_queue' is closed and has insufficient elements (requested 2, current size 0)
     [[Node: shuffle_batch_join = QueueDequeueManyV2[component_types=[DT_UINT8], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/device:CPU:0"](shuffle_batch_join/random_shuffle_queue, shuffle_batch_join/n)]]
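The message says the batching queue was closed while it held fewer elements (0) than the dequeue asked for (2). A blocking dequeue_many only fails once the queue is closed and cannot satisfy the request, and the queue gets closed when the thread feeding it stops. The Python 3 sketch below models just that rule; ClosableQueue, producer and this OutOfRangeError class are hypothetical stand-ins, not TensorFlow's implementation. In the real pipeline the feeding threads stop when something upstream fails; possible causes (an assumption, the traceback does not say which) include glob.glob('dir/*') matching no files or a file that cannot be read. Such an upstream error is normally recorded by the Coordinator and re-raised by coord.join(threads), which this loop never reaches.

```python
import collections
import threading


class OutOfRangeError(Exception):
    """Hypothetical stand-in for tf.errors.OutOfRangeError."""


class ClosableQueue:
    """Toy queue: it can be closed, and dequeue_many(n) needs n items."""

    def __init__(self):
        self._items = collections.deque()
        self._closed = False
        self._cond = threading.Condition()

    def enqueue(self, item):
        with self._cond:
            if self._closed:
                raise RuntimeError("queue is closed")
            self._items.append(item)
            self._cond.notify_all()

    def close(self):
        with self._cond:
            self._closed = True
            self._cond.notify_all()

    def dequeue_many(self, n):
        with self._cond:
            # Wait with no timeout (like timeout_ms=-1) until n items exist or the queue closes.
            while len(self._items) < n and not self._closed:
                self._cond.wait()
            if len(self._items) < n:
                raise OutOfRangeError(
                    "closed and has insufficient elements (requested %d, current size %d)"
                    % (n, len(self._items)))
            return [self._items.popleft() for _ in range(n)]


def producer(q, examples):
    """Hypothetical feeding thread: enqueue what the readers produced, then close the queue."""
    try:
        for example in examples:
            q.enqueue(example)
    finally:
        q.close()
```

With an empty examples list, the producer closes the queue immediately and q.dequeue_many(2) raises with "requested 2, current size 0", the same numbers as in the traceback.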
Answer 0 (score: 0)
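Your readfile(...) function should return an iterable, so that it can return, for example, features and labels together.
So, to fix the code, change the return statement of readfile(...) to
return [array]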
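The point of returning a list is that the batcher treats each returned list as a set of components and builds one batch per component. A minimal Python 3 sketch of that grouping, where batch_join is a hypothetical helper that deliberately ignores shuffling, threads and queues:

```python
def batch_join(tensors_list, batch_size):
    """Toy model of how a join-style batcher lays out its output.

    Each entry of tensors_list stands for one dequeued example and must be a
    list or tuple of components, e.g. [array] or [array, label]. The result
    holds one batch per component, so features and labels stay aligned.
    """
    num_components = len(tensors_list[0])
    batch = [[] for _ in range(num_components)]
    for components in tensors_list[:batch_size]:
        for i, value in enumerate(components):
            batch[i].append(value)
    return batch


# readfile() returning [array]: one component, so one batch of arrays
print(batch_join([[[1, 2]], [[3, 4]]], 2))  # [[[1, 2], [3, 4]]]
# returning [array, label] instead: two aligned batches
print(batch_join([[[1, 2], "a"], [[3, 4], "b"]], 2))  # [[[1, 2], [3, 4]], ['a', 'b']]
```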
Answer 1 (score: 0)
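From the tf.train.shuffle_batch_join documentation: the tensors_list argument is a list of tuples of tensors.
Here, the call to tf.decode_raw produces Tensor instances, and those are collected into a list by thread_lists = [readfile(filenames_queue) for _ in range(num_threads)].
So instead of a list of tuples of tensors, you are passing a list of tensors. shuffle_batch_join then tries to iterate over each tensor, hence the error TypeError: 'Tensor' object is not iterable.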
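The mismatch can be reproduced without TensorFlow. In this Python 3 sketch, GraphTensor and component_names are hypothetical stand-ins: the first only mimics a graph-mode Tensor refusing iteration, the second only mimics a batcher iterating over each entry of tensors_list. Neither is TensorFlow code.

```python
class GraphTensor:
    """Hypothetical stand-in for a graph-mode tf.Tensor, which cannot be iterated."""

    def __init__(self, name):
        self.name = name

    def __iter__(self):
        raise TypeError("'Tensor' object is not iterable")


def component_names(tensors_list):
    """Walk tensors_list the way a join-style batcher does: each entry is
    expected to be a tuple (or list) of tensors, so it gets iterated."""
    return [[t.name for t in tensors] for tensors in tensors_list]


arrays = [GraphTensor("decode_raw_%d" % i) for i in range(2)]

# A list of tuples of tensors works
print(component_names([(a,) for a in arrays]))  # [['decode_raw_0'], ['decode_raw_1']]

# A plain list of tensors fails as soon as one tensor is iterated
try:
    component_names(arrays)
except TypeError as e:
    print(e)  # 'Tensor' object is not iterable
```

Wrapping each tensor, as in return [array], turns the plain list into the expected list of sequences.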