我想使用一个无法加载到内存中的大型数据集来训练TensorFlow模型。但我不知道应该做些什么。
我已经阅读了一些关于TFRecords
文件格式和官方文档的精彩帖子。巴士我还是想不出来。
TensorFlow是否有完整的解决方案计划?
答案 0 :(得分:2)
考虑使用tf.TextLineReader
与tf.train.string_input_producer
结合使用,允许您从磁盘上的多个文件加载数据(如果您的数据集足够大,需要将其分散到多个文件中)。
请参阅https://www.tensorflow.org/programmers_guide/reading_data#reading_from_files
上述链接中的代码段:
filename_queue = tf.train.string_input_producer(["file0.csv", "file1.csv"])
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
# Default values, in case of empty columns. Also specifies the type of the
# decoded result.
record_defaults = [[1], [1], [1], [1], [1]]
col1, col2, col3, col4, col5 = tf.decode_csv(
value, record_defaults=record_defaults)
features = tf.stack([col1, col2, col3, col4])
with tf.Session() as sess:
# Start populating the filename queue.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
for filename_queue = tf.train.string_input_producer(["file0.csv", "file1.csv"])
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
# Default values, in case of empty columns. Also specifies the type of the
# decoded result.
record_defaults = [[1], [1], [1], [1], [1]]
col1, col2, col3, col4, col5 = tf.decode_csv(
value, record_defaults=record_defaults)
features = tf.stack([col1, col2, col3, col4])
with tf.Session() as sess:
# Start populating the filename queue.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
for i in range(1200):
# Retrieve a single instance:
example, label = sess.run([features, col5])
coord.request_stop()
coord.join(threads)i in range(1200):
# Retrieve a single instance:
example, label = sess.run([features, col5])
coord.request_stop()
coord.join(threads)
答案 1 :(得分:1)
通常,您仍然使用批量培训,以便您可以即时加载数据。例如图像:
for bid in nrBatches:
batch_x, batch_y = load_data_from_hd(bid)
train_step.run(feed_dict={x: batch_x, y_: batch_y})
因此,您即时加载每个批处理,只加载您在任何给定时刻需要加载的数据。当然,在使用硬盘而不是内存来加载数据时,您的训练时间会增加。