Solving memory problems with huge datasets in Keras

Time: 2018-01-08 11:21:32

Tags: tensorflow machine-learning deep-learning keras imagenet

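I have this code, which uses Keras 2.0.1 and its generators to do image recognition with deep learning. When run on a GPU, it currently handles 1,500 images fine, but when I start evaluating with 50k images I run into memory problems. I am using flow_from_directory to read the images and predict_generator to get the predictions and probabilities. Below are the error I get and the code I am using: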

ERROR:

2018-01-08 12:28:49.940361: E tensorflow/stream_executor/cuda/cuda_driver.cc:955] failed to alloc 17179869184 bytes on host: CUDA_ERROR_OUT_OF_MEMORY
2018-01-08 12:28:49.968880: W ./tensorflow/core/common_runtime/gpu/pool_allocator.h:195] could not allocate pinned host memory of size: 17179869184

CODE:

from __future__ import division
import numpy as np
from keras import applications
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten

top_model_weights_path = '/home/rehan/ethnicity.071217.23-0.28.hdf5'
path = "/home/rehan/countries/pakistan/guys/test/"
img_width, img_height = 139, 139

confidence = 0.8

model = applications.InceptionResNetV2(include_top=False, weights='imagenet',
                                       input_shape=(img_width, img_height, 3))

print("base pretrained model loaded")



validation_generator = ImageDataGenerator(rescale=1./255).flow_from_directory(path, target_size=(img_width, img_height),
                                                        batch_size=6,shuffle=False)

print("generator built")


features = model.predict_generator(validation_generator)

print("features found")

model = Sequential()
model.add(Flatten(input_shape=(3, 3, 1536)))
model.add(Dense(256, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(6, activation='softmax'))
model.load_weights(top_model_weights_path)
print("top model loaded")
prediction_proba = model.predict_proba(features)
prediction_classes = model.predict_classes(features)
print(prediction_proba)
print(prediction_classes)
print("original file names")
print(validation_generator.filenames)

1 Answer:

Answer 0 (score: 0)

There is an open issue related to this on the Keras GitHub page: https://github.com/keras-team/keras/issues/5835

Maybe you could try loading the images as float32 instead of doubles / float64 (the default), which would halve your memory requirements.
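
To put a rough number on that halving, here is a minimal back-of-the-envelope sketch in plain Python. It assumes the 50k images mentioned in the question and the (3, 3, 1536) feature shape that the question's Flatten layer expects; the helper `array_bytes` and the constants are hypothetical names for illustration, not part of Keras.

```python
# Rough memory estimate for the feature array that predict_generator returns.
# Assumptions taken from the question: 50,000 images, one (3, 3, 1536)
# feature map per image. array_bytes is a hypothetical helper, not a Keras API.

def array_bytes(num_items, item_shape, bytes_per_value):
    """Return the size in bytes of a dense array holding num_items items."""
    values_per_item = 1
    for dim in item_shape:
        values_per_item *= dim
    return num_items * values_per_item * bytes_per_value

NUM_IMAGES = 50000
FEATURE_SHAPE = (3, 3, 1536)
GIB = 1024 ** 3

float32_bytes = array_bytes(NUM_IMAGES, FEATURE_SHAPE, 4)  # 4 bytes per float32
float64_bytes = array_bytes(NUM_IMAGES, FEATURE_SHAPE, 8)  # 8 bytes per float64

print("float32 features: %.2f GiB" % (float32_bytes / GIB))
print("float64 features: %.2f GiB" % (float64_bytes / GIB))
```

Whether the arrays really are float64 depends on the float type Keras is configured with (its floatx setting), so it is worth checking features.dtype before counting on this saving.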