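I have a very large training matrix stored in h5 format that holds images and their corresponding classes. I want to read the images and their labels to train a model in Keras. Because the dataset is too large to load at once, I wrote my own data generator function, imageLoader(), and pass it to model.fit_generator. imageLoader() reads the training data in batches of batch_size samples and feeds them to training.
The training matrix (h5 file) contains the following data:
Train_features: (58160 x 25 x 25 x 50) => no. of images x rows x columns x channels
Train_class: (58160 x 1) => no. of images x label
However, I still get a memory error when using fit_generator in Keras. Here is my code: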
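For a rough sense of scale, the shapes above can be turned into byte counts. The sketch below assumes the features are stored as 4-byte float32 values, which is an assumption and not something the file description states. The helper name array_bytes is made up for illustration.

```python
from math import prod

def array_bytes(shape, itemsize):
    """Bytes needed to hold a dense array of the given shape."""
    return prod(shape) * itemsize

ITEMSIZE = 4  # assumption: float32; the dtype is not stated in the post

full_shape = (58160, 25, 25, 50)   # images x rows x columns x channels
batch_shape = (100, 25, 25, 50)    # one batch of 100 images

full_bytes = array_bytes(full_shape, ITEMSIZE)
batch_bytes = array_bytes(batch_shape, ITEMSIZE)

print(f"whole dataset: {full_bytes / 1e9:.2f} GB")
print(f"one batch:     {batch_bytes / 1e6:.1f} MB")
```

Under that assumption the full feature array is about 7.27 GB, while a single batch of 100 images is about 12.5 MB.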
import os
os.environ['PYTHONHASHSEED'] = '0'
import numpy as np
np.random.seed(7)
from tensorflow import set_random_seed
set_random_seed(2)
import keras
from keras.layers import Input, Dense, Dropout, BatchNormalization
from keras.models import Model, Sequential
from keras.layers import Conv2D, MaxPooling2D, Flatten, Activation
from keras.callbacks import ReduceLROnPlateau
import h5py
h5f = h5py.File('path to h5 file', 'r')
#Data Generator
def imageLoader(h5f, batch_size):
    L = len(h5f['train_class'])
    while True:
        batch_start = 0
        batch_end = batch_size
        while batch_start < L:
            limit = min(batch_end, L)
            X = h5f['train_features'][batch_start:limit]
            Y = h5f['train_class'][batch_start:limit]
            print(batch_start, batch_end)
            yield (X, keras.utils.to_categorical(Y, 2))  # I have two classes
            batch_start += batch_size
            batch_end += batch_size
# Building Model
model= Sequential()
model.add(Conv2D(80,(4,4), activation='relu', input_shape= (25,25,50)))
model.add(MaxPooling2D(pool_size= (2,2)))
model.add(Conv2D(150,(4,4), activation = 'relu'))
model.add(MaxPooling2D(pool_size= (2,2)))
model.add(Conv2D(200,(4,4), activation = 'relu'))
model.add(Flatten())
model.add(Dense(2, activation = 'softmax'))
model.summary()
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit_generator(imageLoader(h5f,100),
steps_per_epoch=len(h5f['train_class'])/100, epochs=3)
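A side note on the fit_generator call above: in Python 3, len(...)/100 is a float (581.6 for 58160 samples), while steps_per_epoch is documented as an integer number of batches. Below is a minimal sketch of the batch arithmetic using only the standard library. The helper names steps_per_epoch and batch_bounds are hypothetical and not part of Keras.

```python
import math

def steps_per_epoch(n_samples, batch_size):
    """Number of batches needed to cover every sample once."""
    return math.ceil(n_samples / batch_size)

def batch_bounds(n_samples, batch_size):
    """Yield (start, stop) slice bounds for one pass over the data."""
    for start in range(0, n_samples, batch_size):
        yield start, min(start + batch_size, n_samples)

n, bs = 58160, 100
bounds = list(batch_bounds(n, bs))
print(steps_per_epoch(n, bs))   # 582
print(bounds[0], bounds[-1])    # (0, 100) (58100, 58160)
```

Rounding up means the final, shorter batch of 60 samples is still counted as a step.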
Here is the error log:
Epoch 1/3
0 100
100 200
200 300
300 400
400 500
500 600
600 700
700 800
800 900
900 1000
1000 1100
1/54 [..............................] - ETA: 22s - loss: 0.9183 - acc: 0.1100
1100 1200
1200 1300
1300 1400
1400 1500
5/54 [=>............................] - ETA: 4s - loss: 0.9579 - acc: 0.0380
1500 1600
7/54 [==>...........................] - ETA: 42s - loss: 0.8970 - acc: 0.1429
Exception in thread Thread-28:
Traceback (most recent call last):
File "/home/saror/.pyenv/versions/anaconda3-4.4.0/envs/py3dl2/lib/python3.6/threading.py", line 916, in _bootstrap_inner
self.run()
File "/home/saror/.pyenv/versions/anaconda3-4.4.0/envs/py3dl2/lib/python3.6/threading.py", line 864, in run
self._target(*self._args, **self._kwargs)
File "/home/saror/.pyenv/versions/anaconda3-4.4.0/envs/py3dl2/lib/python3.6/site-packages/keras/utils/data_utils.py", line 579, in data_generator_task
generator_output = next(self._generator)
File "<ipython-input-119-7a1d852e803c>", line 10, in imageLoader
X= h5f['test_features'][batch_start:limit]
File "h5py/_objects.pyx", line 54, in h5py._objects.with_phil.wrapper
File "h5py/_objects.pyx", line 55, in h5py._objects.with_phil.wrapper
File "/home/saror/.pyenv/versions/anaconda3-4.4.0/envs/py3dl2/lib/python3.6/site-packages/h5py/_hl/dataset.py", line 496, in __getitem__
self.id.read(mspace, fspace, arr, mtype, dxpl=self._dxpl)
File "h5py/_objects.pyx", line 54, in h5py._objects.with_phil.wrapper
File "h5py/_objects.pyx", line 55, in h5py._objects.with_phil.wrapper
File "h5py/h5d.pyx", line 181, in h5py.h5d.DatasetID.read
File "h5py/_proxy.pyx", line 130, in h5py._proxy.dset_rw
File "h5py/_proxy.pyx", line 84, in h5py._proxy.H5PY_H5Dread
OSError: Can't read data (file read failed: time = Wed Oct 3 23:43:01 2018
, filename = '/mnt/hdd/shreya/change_detection_files/san_test_data.h5', file
descriptor = 86, errno = 5, error message = 'Input/output error', buf =
0x7f53b2990c80, total read size = 6415424, bytes this sub-read = 6415424,
bytes actually read = 18446744073709551615, offset = 418586624)
Traceback (most recent call last):
File "<ipython-input-120-e74f6fee0af9>", line 15, in <module>
model.fit_generator(imageLoader(h5f,100),
steps_per_epoch=len(h5f['test_class'])/100, epochs=3)
File "/home/saror/.pyenv/versions/anaconda3-4.4.0/envs/py3dl2/lib/python3.6/site-packages/keras/legacy/interfaces.py", line 87, in wrapper
return func(*args, **kwargs)
File "/home/saror/.pyenv/versions/anaconda3-4.4.0/envs/py3dl2/lib/python3.6/site-packages/keras/models.py", line 1156, in fit_generator
initial_epoch=initial_epoch)
File "/home/saror/.pyenv/versions/anaconda3-4.4.0/envs/py3dl2/lib/python3.6/site-packages/keras/legacy/interfaces.py", line 87, in wrapper
return func(*args, **kwargs)
File "/home/saror/.pyenv/versions/anaconda3-4.4.0/envs/py3dl2/lib/python3.6/site-packages/keras/engine/training.py", line 2046, in fit_generator
generator_output = next(output_generator)
StopIteration
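The first traceback above ends in an OSError (errno 5, 'Input/output error') raised inside the h5py read, and the StopIteration in fit_generator follows it. One way to check whether the reads themselves fail, independent of Keras, is to walk the dataset in the same slices and record which ones raise. This is only a sketch: find_unreadable_slices is a hypothetical helper, and it assumes nothing beyond the dataset supporting len() and slicing, as an h5py dataset does. FlakyDataset is a stand-in used to exercise the helper, not real data.

```python
def find_unreadable_slices(dataset, batch_size):
    """Read the dataset batch by batch and return (start, stop, message)
    for every slice whose read raised OSError."""
    failures = []
    n = len(dataset)
    for start in range(0, n, batch_size):
        stop = min(start + batch_size, n)
        try:
            dataset[start:stop]  # force the read; the data is discarded
        except OSError as exc:
            failures.append((start, stop, str(exc)))
    return failures


class FlakyDataset:
    """Hypothetical stand-in: slicing across bad_index raises OSError."""

    def __init__(self, n, bad_index):
        self.n = n
        self.bad_index = bad_index

    def __len__(self):
        return self.n

    def __getitem__(self, key):
        start, stop, _ = key.indices(self.n)
        if start <= self.bad_index < stop:
            raise OSError(5, "Input/output error")
        return list(range(start, stop))


failures = find_unreadable_slices(FlakyDataset(1000, 650), 100)
print(failures)
```

With the real file this could be called as find_unreadable_slices(h5f['train_features'], 100). An empty result would suggest that the reads succeed when done sequentially from a single thread.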
Can anyone tell me what is wrong in the code above? And how can I read large data efficiently for model training?
Thanks!