TensorFlow需要很长时间才能将数据加载到tf.Dataset

时间:2018-08-13 00:38:02

标签: tensorflow

我正在使用TensorFlow 1.9训练图像数据集,该数据集太大而无法从硬盘驱动器加载到RAM中。因此,我将数据集分成了两半。我想知道在整个数据集中进行训练的最有效方法是什么。

我的GPU有3 GB的内存,而我的RAM有32 GB的内存。每个半数据集的大小为20 GB。我的硬盘驱动器具有足够的可用空间(超过1 TB)。

我的尝试如下。我创建了一个可初始化的tf.Dataset,然后在每个时期都将其初始化两次:对数据集的每半部分一次。这样,每个时期都可以看到整个数据集,但任何时候都只需将一半的数据加载到RAM中即可。

但是,这非常慢,因为从我的硬盘驱动器加载数据需要花费很长时间,并且每次用此数据初始化数据集也需要很长时间。

是否有更有效的方法?

在加载数据集的另一半之前,我曾尝试在数据集的每一半上训练多个时期,这要快得多,但这会使验证数据的性能大大降低。据推测,这是因为该模型在每半部分都过度拟合,然后没有推广到另一半中的数据。

在下面的代码中,我创建并保存了一些测试数据,然后如上所述进行加载。加载每个半数据集的时间约为5秒,而使用此数据初始化数据集的时间约为1秒。这看起来似乎只是小数目,但所有这些总和要经过多个时期。实际上,我的计算机花费在数据加载上的时间几乎与它在数据上进行实际训练所花费的时间一样。

import tensorflow as tf
import numpy as np
import time

# Create and save 2 datasets of test NumPy data
dataset_num_elements = 100000
element_dim = 10000
batch_size = 50
test_data = np.zeros([2, int(dataset_num_elements * 0.5), element_dim], dtype=np.float32)
np.savez('test_data_1.npz', x=test_data[0])
np.savez('test_data_2.npz', x=test_data[1])

# Create the TensorFlow dataset
data_placeholder = tf.placeholder(tf.float32, [int(dataset_num_elements * 0.5), element_dim])
dataset = tf.data.Dataset.from_tensor_slices(data_placeholder)
dataset = dataset.shuffle(buffer_size=dataset_num_elements)
dataset = dataset.repeat()
dataset = dataset.batch(batch_size=batch_size)
dataset = dataset.prefetch(1)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
init_op = iterator.initializer

num_batches = int(dataset_num_elements / batch_size)

with tf.Session() as sess:
    while True:
        for dataset_section in range(2):
            # Load the data from the hard drive
            t1 = time.time()
            print('Loading')
            loaded_data = np.load('test_data_' + str(dataset_section + 1) + '.npz')
            x = loaded_data['x']
            print('Loaded')
            t2 = time.time()
            loading_time = t2 - t1
            print('Loading time = ' + str(loading_time))
            # Initialize the dataset with this loaded data
            t1 = time.time()
            sess.run(init_op, feed_dict={data_placeholder: x})
            t2 = time.time()
            initialization_time = t2 - t1
            print('Initialization time = ' + str(initialization_time))
            # Read the data in batches
            for i in range(num_batches):
                x = sess.run(next_element)

1 个答案:

答案 0 :(得分:2)

进纸不是输入数据的有效方法。您可以这样输入数据:

  1. 创建一个包含所有输入文件名的文件名数据集。您可以随机播放,在此处重复数据集。
  2. 将此数据集映射到数据,映射功能是读取,解码,转换图像。使用多线程进行地图转换。
  3. 预取要训练的数据。

这只是示例方式。您可以设计自己的管道,请记住以下几点:

  • 尽可能使用轻量级Feed
  • 使用多线程读取和预处理
  • 预取数据进行训练