如何将LMDB文件加载到TensorFlow中?

时间:2016-05-20 03:36:30

标签: machine-learning tensorflow

我有一个大的(1 TB)数据集,分为大约3,000个CSV文件。我的计划是将其转换为一个大的LMDB文件,以便可以快速读取它以训练神经网络。但是,我无法找到有关如何将LMDB文件加载到TensorFlow的任何文档。有谁知道如何做到这一点?我知道TensorFlow可以读取CSV文件,但我相信这会太慢。

1 个答案:

答案 0 :(得分:7)

根据this,有几种方法可以在TensorFlow中读取数据。

最简单的方法是通过占位符提供数据。当使用占位符时 - 洗牌和批处理的责任在你身上。

如果要将shuffling和batching委托给框架,则需要创建输入管道。问题是这样 - 如何将lmdb数据注入符号输入管道。可能的解决方案是使用tf.py_func操作。这是一个例子:

def create_input_pipeline(lmdb_env, keys, num_epochs=10, batch_size=64):
   key_producer = tf.train.string_input_producer(keys, 
                                                 num_epochs=num_epochs,
                                                 shuffle=True)
   single_key = key_producer.dequeue()

   def get_bytes_from_lmdb(key):
      with lmdb_env.begin() as txn:
         lmdb_val = txn.get(key)
      example = get_example_from_val(lmdb_val) # A single example (numpy array)
      label = get_label_from_val(lmdb_val)     # The label, could be a scalar
      return example, label

   single_example, single_label = tf.py_func(get_bytes_from_lmdb,
                                             [single_key], [tf.float32, tf.float32])
   # if you know the shapes of the tensors you can set them here:
   # single_example.set_shape([224,224,3])

   batch_examples, batch_labels = tf.train.batch([single_example, single_label],
                                                 batch_size)
   return batch_examples, batch_labels

tf.py_func op在 TensorFlow 图中插入对常规python代码的调用,我们需要指定输入以及输出的数量和类型。 tf.train.string_input_producer使用给定的密钥创建一个混洗队列。 tf.train.batch op创建另一个包含批量数据的队列。在培训时,batch_examplesbatch_labels的每次评估都会从该队列中取出另一批次。

因为我们创建了队列,所以在开始训练之前我们需要注意并运行QueueRunner对象。这样做(来自 TensorFlow doc):

# Create the graph, etc.
init_op = tf.initialize_all_variables()

# Create a session for running operations in the Graph.
sess = tf.Session()

# Initialize the variables (like the epoch counter).
sess.run(init_op)

# Start input enqueue threads.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

try:
    while not coord.should_stop():
        # Run training steps or whatever
        sess.run(train_op)

except tf.errors.OutOfRangeError:
    print('Done training -- epoch limit reached')
finally:
    # When done, ask the threads to stop.
    coord.request_stop()

# Wait for threads to finish.
coord.join(threads)
sess.close()