In TensorFlow

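Date: 2018-01-22 08:29:29

Tags: python tensorflow

I am new to TensorFlow and am currently learning how to use queue runners. What I want to do is read binary files from a directory and turn each file into an array. I use two threads and build batches of 4 arrays. The code is as follows.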

    import glob

    import tensorflow as tf


    def readfile(filenames_queue):
        # Dequeue one filename and decode the file's raw bytes as uint8.
        filename = filenames_queue.dequeue()
        value_strings = tf.read_file(filename)
        array = tf.decode_raw(value_strings, tf.uint8)
        return [array]


    def input_pipeline(filenames, batch_size, num_threads=2):
        filenames_queue = tf.train.string_input_producer(filenames)
        # One reader per thread; each entry is a list of tensors.
        thread_lists = [readfile(filenames_queue) for _ in range(num_threads)]
        min_after_dequeue = 1000
        capacity = min_after_dequeue + 3 * batch_size
        arrays = tf.train.shuffle_batch_join(thread_lists, batch_size, capacity, min_after_dequeue)
        return arrays


    if __name__ == "__main__":
        filenames = glob.glob('dir/*')
        arrays_batch = input_pipeline(filenames, 4)
        with tf.Session() as sess:
            tf.global_variables_initializer().run()
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess, coord)
            for i in range(100):
                print sess.run(arrays_batch)
            coord.request_stop()
            coord.join(threads)
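For reference, the per-file result this pipeline is meant to produce can be sketched without TensorFlow. The snippet below is a minimal plain-Python sketch, assuming each file should become a flat sequence of unsigned 8-bit values, which is what `tf.decode_raw(..., tf.uint8)` yields for a byte string. The helper name `read_file_as_uint8` and the temporary file are hypothetical and exist only for illustration.

```python
import os
import tempfile


def read_file_as_uint8(path):
    """Return the file's bytes as a list of ints in the range 0..255."""
    with open(path, "rb") as f:
        data = f.read()
    # bytearray yields one integer per byte on both Python 2 and Python 3.
    return list(bytearray(data))


# Write a small throwaway binary file so the example is self-contained.
fd, tmp_path = tempfile.mkstemp()
try:
    with os.fdopen(fd, "wb") as f:
        f.write(bytearray([0, 1, 254, 255]))
    values = read_file_as_uint8(tmp_path)
finally:
    os.remove(tmp_path)

print(values)
```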

I have fixed the errors pointed out by Victor and Sorin, but now a new error appears:

    File "input_queue.py", line 36, in <module>
      print sess.run(im_arrays_batch)
    File "/usr/local/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 889, in run
      run_metadata_ptr)
    File "/usr/local/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1120, in _run
      feed_dict_tensor, options, run_metadata)
    File "/usr/local/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1317, in _do_run
      options, run_metadata)
    File "/usr/local/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1336, in _do_call
      raise type(e)(node_def, op, message)
    tensorflow.python.framework.errors_impl.OutOfRangeError: RandomShuffleQueue '_1_shuffle_batch_join/random_shuffle_queue' is closed and has insufficient elements (requested 2, current size 0)
         [[Node: shuffle_batch_join = QueueDequeueManyV2[component_types=[DT_UINT8], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/device:CPU:0"](shuffle_batch_join/random_shuffle_queue, shuffle_batch_join/n)]]

    Caused by op u'shuffle_batch_join', defined at:
    File "input_queue.py", line 30, in <module>
      im_arrays_batch = input_pipeline(filenames,2)
    File "input_queue.py", line 23, in input_pipeline
      arrays_batch = tf.train.shuffle_batch_join(thread_lists,batch_size,capacity,min_after_dequeue)
    File "/usr/local/anaconda2/lib/python2.7/site-packages/tensorflow/python/training/input.py", line 1367, in shuffle_batch_join
      name=name)
    File "/usr/local/anaconda2/lib/python2.7/site-packages/tensorflow/python/training/input.py", line 833, in _shuffle_batch_join
      dequeued = queue.dequeue_many(batch_size, name=name)
    File "/usr/local/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/data_flow_ops.py", line 464, in dequeue_many
      self._queue_ref, n=n, component_types=self._dtypes, name=name)
    File "/usr/local/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/gen_data_flow_ops.py", line 2418, in _queue_dequeue_many_v2
      component_types=component_types, timeout_ms=timeout_ms, name=name)
    File "/usr/local/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
      op_def=op_def)
    File "/usr/local/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2956, in create_op
      op_def=op_def)
    File "/usr/local/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1470, in __init__
      self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

    OutOfRangeError (see above for traceback): RandomShuffleQueue '_1_shuffle_batch_join/random_shuffle_queue' is closed and has insufficient elements (requested 2, current size 0)
         [[Node: shuffle_batch_join = QueueDequeueManyV2[component_types=[DT_UINT8], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/device:CPU:0"](shuffle_batch_join/random_shuffle_queue, shuffle_batch_join/n)]]
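To help read the message, here is a toy model of the condition it describes: a `dequeue_many(n)` call on a queue that has been closed while holding fewer than `n` elements. This is only a sketch in plain Python; `ToyQueue` and `ToyOutOfRangeError` are hypothetical stand-ins, not TensorFlow classes, and the model says nothing about why the real queue was closed before any element was enqueued.

```python
class ToyOutOfRangeError(Exception):
    """Hypothetical stand-in for tf.errors.OutOfRangeError."""


class ToyQueue(object):
    """Toy FIFO queue; not the TensorFlow implementation."""

    def __init__(self):
        self.items = []
        self.closed = False

    def enqueue(self, item):
        if self.closed:
            raise RuntimeError("cannot enqueue into a closed queue")
        self.items.append(item)

    def close(self):
        self.closed = True

    def dequeue_many(self, n):
        if len(self.items) < n:
            if self.closed:
                # No producer can add more items, so the request can never be met.
                raise ToyOutOfRangeError(
                    "queue is closed and has insufficient elements "
                    "(requested %d, current size %d)" % (n, len(self.items)))
            # A real queue would block here until producers add more items.
            raise RuntimeError("toy queue cannot block; enqueue more items first")
        batch, self.items = self.items[:n], self.items[n:]
        return batch


queue = ToyQueue()
queue.close()  # Closed before any producer enqueued anything.
try:
    queue.dequeue_many(2)
    message = None
except ToyOutOfRangeError as err:
    message = str(err)

print(message)
```

In the traceback above, "requested 2, current size 0" corresponds to a batch size of 2 and a shuffle queue that was empty when it was closed.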

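2 Answers:

Answer 0: (score: 0)

Your readfile(...) function should return an iterable, so that you can return features and labels or something similar.

So to fix the code, change the return statement in readfile(...) to:

    return [array]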
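Why the return value should be a list can be illustrated with a toy model. The sketch below is plain Python, not TensorFlow: `toy_batch_join` is a hypothetical helper in which each inner list stands for what one `readfile` call returned, and the result has one batched output per component, mirroring how `tf.train.shuffle_batch_join` is documented to return tensors with the same number and types as `tensors_list[i]`. The string and integer values are placeholders for tensors.

```python
def toy_batch_join(tensors_list, batch_size):
    """Toy model: group the inner lists component-wise into one batch.

    tensors_list: list of inner lists, each holding the components
    (for example an array and a label) produced by one reader.
    Returns one list per component, each holding batch_size values.
    """
    num_components = len(tensors_list[0])
    for components in tensors_list:
        if len(components) != num_components:
            raise ValueError("every entry needs the same number of components")
    chosen = tensors_list[:batch_size]
    return [[components[i] for components in chosen]
            for i in range(num_components)]


# One component per reader, as with `return [array]`.
only_arrays = toy_batch_join([["array_a"], ["array_b"]], batch_size=2)

# Two components per reader, as with a hypothetical `return [array, label]`.
with_labels = toy_batch_join([["array_a", 0], ["array_b", 1]], batch_size=2)

print(only_arrays)
print(with_labels)
```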

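Answer 1: (score: 0)

From the tf.train.shuffle_batch_join documentation:

"The tensors_list argument is a list of tuples of tensors"

Here, the call to tf.decode_raw produces Tensor instances, and thread_lists = [readfile(filenames_queue) for _ in range(num_threads)] collects them into a list.

So instead of a list of tuples of tensors, you are passing a list of tensors. The function then tries to iterate over each tensor, which raises the error TypeError: 'Tensor' object is not iterable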
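The failure mode can be reproduced in miniature without TensorFlow. In the sketch below, `FakeTensor` is a hypothetical stand-in that, like a graph-mode tensor, cannot be iterated, and `count_components` is a hypothetical helper that iterates each entry the way code expecting tuples of tensors would. It is only an illustration of the mechanism, not the library's actual validation code.

```python
class FakeTensor(object):
    """Hypothetical stand-in for a tensor: it deliberately cannot be iterated."""

    def __init__(self, name):
        self.name = name


def count_components(tensors_list):
    """Iterate each entry, as code expecting tuples of tensors would."""
    return [len(list(entry)) for entry in tensors_list]


a, b = FakeTensor("a"), FakeTensor("b")

# List of lists of tensors: every entry can be iterated.
good = count_components([[a], [b]])

# List of bare tensors: iterating an entry raises TypeError.
try:
    count_components([a, b])
    error = None
except TypeError as err:
    error = str(err)

print(good)
print(error)
```

Wrapping each tensor in a list, as `return [array]` does, is what turns the second case into the first.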