在我们拥有庞大的数据集时训练模型的方法,以避免错误记忆

时间:2017-08-03 17:41:12

标签: keras

我正在尝试使用具有更多250k图像的数据集来训练神经网络。但是由于我的有限计算机具有16 GB RAM和32 GB SWAP,我陷入困境。在我加载所有图像之前,它给了我一个内存错误。所以我想知道是否有办法使用我拥有的所有图像来训练我的神经网络?例如,我们可以将它加载到可用空间磁盘上而不是使用RAM内存来加载numpy数组中的图像吗?

编辑1:

def get_array_image(file, path):
    return cv2.imread(path+file)
def generator(features, labels, num_classes, batch_size, path=''):
     # Create empty arrays to contain batch of features and labels#
     batch_features = np.zeros((batch_size, 28, 28, 3))
     batch_labels = np.zeros((batch_size, 1))
     while True:
       for cpt in range(0, len(features), batch_size):
         for i in range(0, batch_size):
             index = cpt + i
             #print('images : ', index)
             batch_features[i] = get_array_image(features[index], path)
             batch_labels[i] = labels[index]
         yield batch_features, keras.utils.to_categorical(batch_labels, num_classes)

这是我与fit_generator函数一起使用的生成器。但我有准确性的问题。事实上,我在具有小神经网络的mnist数据集上尝试了它。如果我使用fit函数加载所有图像(大约60k图像用于训练,60k用于测试)我在一个纪元后有大约0.68的准确度。但是使用fit_generator我只获得0.1。我的发电机出了什么问题吗?当我打印索引变量时,它似乎很好。

编辑2:我解决了我的问题,但我不明白为什么会这样。事实上,当我在循环外部创建数组时,我获得了较低的精度,但是当它在内部时,fit_generator的精度很高。有人知道我错过了什么吗?

def generator(features, labels, num_classes, batch_size, path='', dtype=np.uint8):
     # Create empty arrays to contain batch of features and labels#
     # batch_features = np.ndarray(shape=(batch_size, 28, 28, 3), dtype=dtype)
     # batch_labels =  np.ndarray(shape=(batch_size, 1), dtype=dtype)
     while True:
       for cpt in range(0, len(features), batch_size):

         batch_features = np.ndarray(shape=(batch_size, 28, 28, 3), dtype=dtype)
         batch_labels =  np.ndarray(shape=(batch_size, 1), dtype=dtype)
         for i in range(0, batch_size):
             # index= random.randint(0, len(features)-1)
             index = cpt + i
             #print('images : ', index)
             batch_features[i] = get_array_image(features[index], path)
             batch_labels[i] = labels[index]
#             print(batch_labels[i])
#             cv2.imshow('image', batch_features[i])
#             cv2.waitKey(0)
#             cv2.destroyAllWindows()
             # print(features[index])

         print(batch_features.shape)
         yield batch_features, keras.utils.to_categorical(batch_labels, num_classes)

0 个答案:

没有答案