如果我想使用无法通过TensorFlow加载到内存中的大型数据集,我该怎么办?

时间:2017-02-22 11:20:05

标签: python machine-learning tensorflow deep-learning bigdata

我想使用一个无法加载到内存中的大型数据集来训练TensorFlow模型。但我不知道应该做些什么。

我已经阅读了一些关于TFRecords文件格式和官方文档的精彩帖子。巴士我还是想不出来。

TensorFlow是否有完整的解决方案计划?

2 个答案:

答案 0 :(得分:2)

考虑使用tf.TextLineReadertf.train.string_input_producer结合使用,允许您从磁盘上的多个文件加载数据(如果您的数据集足够大,需要将其分散到多个文件中)。

请参阅https://www.tensorflow.org/programmers_guide/reading_data#reading_from_files

上述链接中的代码段:

filename_queue = tf.train.string_input_producer(["file0.csv", "file1.csv"])

reader = tf.TextLineReader()
key, value = reader.read(filename_queue)

# Default values, in case of empty columns. Also specifies the type of the
# decoded result.
record_defaults = [[1], [1], [1], [1], [1]]
col1, col2, col3, col4, col5 = tf.decode_csv(
    value, record_defaults=record_defaults)
features = tf.stack([col1, col2, col3, col4])

with tf.Session() as sess:
  # Start populating the filename queue.
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(coord=coord)

  for     filename_queue = tf.train.string_input_producer(["file0.csv", "file1.csv"])

reader = tf.TextLineReader()
key, value = reader.read(filename_queue)

# Default values, in case of empty columns. Also specifies the type of the
# decoded result.
record_defaults = [[1], [1], [1], [1], [1]]
col1, col2, col3, col4, col5 = tf.decode_csv(
    value, record_defaults=record_defaults)
features = tf.stack([col1, col2, col3, col4])

with tf.Session() as sess:
  # Start populating the filename queue.
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(coord=coord)

  for i in range(1200):
    # Retrieve a single instance:
    example, label = sess.run([features, col5])

  coord.request_stop()
  coord.join(threads)i in range(1200):
    # Retrieve a single instance:
    example, label = sess.run([features, col5])

  coord.request_stop()
  coord.join(threads)

答案 1 :(得分:1)

通常,您仍然使用批量培训,以便您可以即时加载数据。例如图像:

for bid in nrBatches:
     batch_x, batch_y = load_data_from_hd(bid)
     train_step.run(feed_dict={x: batch_x, y_: batch_y})

因此,您即时加载每个批处理,只加载您在任何给定时刻需要加载的数据。当然,在使用硬盘而不是内存来加载数据时,您的训练时间会增加。