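I created a data generator to pass to Keras's fit_generator(), but the amount of data is not always an exact multiple of the batch size and steps_per_epoch, so when an epoch ends I want to make sure the data generator resets and starts again from the beginning of the (HDF5) dataset. My ML classifier is an LSTM and this is time-series data, so order matters.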
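To make the mismatch concrete, here is a minimal sketch of the arithmetic. The record count of 5100 is a made-up figure for illustration only (the real dataset length is not stated in this question), and `batch_layout` is a hypothetical helper name; the batch size of 200 matches the index ranges printed in the run output.

```python
def batch_layout(nrecs, batch_size):
    """Return (full_batches, leftover_records) for one pass over the data."""
    full_batches, leftover = divmod(nrecs, batch_size)
    return full_batches, leftover


# Hypothetical record count, chosen only for illustration.
nrecs = 5100
batch_size = 200

full_batches, leftover = batch_layout(nrecs, batch_size)
print("full batches per pass:", full_batches)
print("records left over:", leftover)
```

Whenever the leftover is non-zero, the last few records cannot form a complete batch, which is the situation I am trying to handle.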
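I know that when you use the version of the data generator that is built as a Python class, you can use the on_epoch_end() callback, as described in Afshine Amidi's article "A detailed example of how to use data generators in Keras", but I could never get the class version of the generator to work. In any case, I have a huge dataset, my Jupyter kernel keeps dying, I get memory errors even in PyCharm, and I would prefer to use the simpler version of the data generator.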
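For reference, the class-based approach has roughly the following shape. This is only a sketch of the interface, not the article's code: `BatchIndexSequence` is a hypothetical name, it returns index ranges instead of real data, the size of 5100 records is invented, and it is a plain Python class so that it runs without Keras installed. For actual use it would presumably subclass `keras.utils.Sequence`, which uses the same `__len__`, `__getitem__` and `on_epoch_end` method names.

```python
class BatchIndexSequence:
    """Hypothetical stand-in mirroring the keras.utils.Sequence interface."""

    def __init__(self, nrecs, batch_size):
        self.nrecs = nrecs
        self.batch_size = batch_size
        self.epochs_seen = 0

    def __len__(self):
        # Number of complete batches per epoch; a trailing partial batch is dropped.
        return self.nrecs // self.batch_size

    def __getitem__(self, batch_idx):
        # Map a batch number to (start, stop) slice bounds, preserving time order.
        if not 0 <= batch_idx < len(self):
            raise IndexError(batch_idx)
        start = batch_idx * self.batch_size
        return start, start + self.batch_size

    def on_epoch_end(self):
        # Hook meant to run after each epoch. Batches are addressed by number,
        # so there is no read position to reset; this sketch only counts epochs.
        self.epochs_seen += 1


seq = BatchIndexSequence(nrecs=5100, batch_size=200)  # 5100 is a made-up size
print(len(seq), seq[0], seq[len(seq) - 1])
```

Because each batch is requested by number, every epoch in this style begins at batch 0 regardless of how many batches were consumed before.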
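Here is the data generator that I pass to Keras's fit_generator(). Note that on every iteration of the generator's loop, the line marked KEY LINE prints the indices used to pull data out of the HDF5 dataset object. If you look at what that line prints in the spew after the code, you will see that the indices are out of sync with the epochs: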
import h5py
import numpy as np


def data_generator(dataID, batch_size, dim=(20, 100)):
    """Yield (X, y) batches from an HDF5 file forever, in stored order."""
    print("\nIn data_generator.\n")
    DataDir = "data/"
    BatchDir = DataDir + dataID + "/"
    h5path = BatchDir + dataID + ".h5"
    f = h5py.File(h5path, "r")
    data = f["sigccm"]
    labels = f["Attack"]
    # number of records to process
    nrecs = len(data)
    outputshape = (batch_size, *dim)
    while True:
        for i in range(0, nrecs, batch_size):
            # Generate one batch of data
            if i + batch_size > nrecs:
                # upperbound = nrecs
                # If we can't get a complete batch
                # out of the remaining data, go
                # ahead and wrap up this epoch and
                # start the next one.
                break
            else:
                upperbound = i + batch_size
            print("\ndata[%d : %d]\n" % (i, upperbound))  # <<<<< KEY LINE of CODE
            X = np.array(data[i : upperbound])
            y = np.array(labels[i : upperbound])
            if outputshape != X.shape:
                msg = "Wrong shape: "
                idx = "index = {:d}, ".format(i)
                shp = "X.shape = {:s}".format(str(X.shape))
                msg = msg + idx + shp
                print(msg)
            else:
                # Label (Attack) field has an extra
                # nested array dimension. Get rid of it.
                u = np.array(y)
                y = np.resize(u, (batch_size, 1))
                yield X, y
    f.close()
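The index pattern of this generator can be reproduced without HDF5 or Keras. The sketch below is a simplified simulation under stated assumptions, not a description of what fit_generator() does internally: `index_generator` and `epoch_starts` are hypothetical names, the record count of 5100 is invented, and the consumer simply pulls a fixed number of batches per epoch with no prefetching and no validation generator, so its numbers are not expected to match my log line for line. It does show how the starting index shifts from epoch to epoch when the number of steps per epoch differs from the number of full batches in one pass.

```python
from itertools import islice


def index_generator(nrecs, batch_size):
    """Yield (start, stop) bounds forever, skipping a trailing partial batch."""
    while True:
        for start in range(0, nrecs, batch_size):
            if start + batch_size > nrecs:
                break  # not enough records left for a full batch
            yield start, start + batch_size


def epoch_starts(nrecs, batch_size, steps_per_epoch, epochs):
    """Return the first slice consumed in each simulated epoch."""
    gen = index_generator(nrecs, batch_size)
    starts = []
    for _ in range(epochs):
        batches = list(islice(gen, steps_per_epoch))
        starts.append(batches[0])
    return starts


# 5100 records is a made-up figure; 200 and 24 are the batch size and
# step count that appear in the run output.
print(epoch_starts(nrecs=5100, batch_size=200, steps_per_epoch=24, epochs=3))
# The same data with one step per full batch, for comparison.
print(epoch_starts(nrecs=5100, batch_size=200, steps_per_epoch=25, epochs=3))
```

In this simulation only the second call starts every epoch at index 0, because the generator never learns where an epoch ends; it just keeps cycling.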
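This is what I get when I run Keras's fit_generator(). The spew ends where it does because the kernel died in Jupyter, and in PyCharm I got a MemoryError on a line of code in the data generator: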
In data_generator.
data[0 : 200]
Epoch 1/10
data[200 : 400]
data[400 : 600]
1/24 [>.............................] - ETA: 1:33 - loss: 0.2952 - acc: 0.1350
data[600 : 800]
2/24 [=>............................] - ETA: 1:10 - loss: 0.2234 - acc: 0.4975
data[800 : 1000]
3/24 [==>...........................] - ETA: 1:01 - loss: 0.1923 - acc: 0.6200
data[1000 : 1200]
4/24 [====>.........................] - ETA: 55s - loss: 0.1753 - acc: 0.6813
data[1200 : 1400]
5/24 [=====>........................] - ETA: 49s - loss: 0.1652 - acc: 0.7170
data[1400 : 1600]
6/24 [======>.......................] - ETA: 45s - loss: 0.1587 - acc: 0.7400
data[1600 : 1800]
7/24 [=======>......................] - ETA: 42s - loss: 0.1536 - acc: 0.7571
data[1800 : 2000]
8/24 [=========>....................] - ETA: 39s - loss: 0.1496 - acc: 0.7700
data[2000 : 2200]
9/24 [==========>...................] - ETA: 36s - loss: 0.1469 - acc: 0.7794
data[2200 : 2400]
10/24 [===========>..................] - ETA: 34s - loss: 0.1447 - acc: 0.7870
data[2400 : 2600]
11/24 [============>.................] - ETA: 31s - loss: 0.1426 - acc: 0.7936
data[2600 : 2800]
12/24 [==============>...............] - ETA: 29s - loss: 0.1411 - acc: 0.7988
data[2800 : 3000]
13/24 [===============>..............] - ETA: 26s - loss: 0.1395 - acc: 0.8035
data[3000 : 3200]
14/24 [================>.............] - ETA: 24s - loss: 0.1385 - acc: 0.8071
data[3200 : 3400]
15/24 [=================>............] - ETA: 21s - loss: 0.1375 - acc: 0.8103
data[3400 : 3600]
16/24 [===================>..........] - ETA: 19s - loss: 0.1365 - acc: 0.8134
data[3600 : 3800]
17/24 [====================>.........] - ETA: 17s - loss: 0.1358 - acc: 0.8159
data[3800 : 4000]
18/24 [=====================>........] - ETA: 14s - loss: 0.1351 - acc: 0.8181
data[4000 : 4200]
19/24 [======================>.......] - ETA: 12s - loss: 0.1344 - acc: 0.8203
data[4200 : 4400]
20/24 [========================>.....] - ETA: 9s - loss: 0.1339 - acc: 0.8220
data[4400 : 4600]
21/24 [=========================>....] - ETA: 7s - loss: 0.1334 - acc: 0.8236
data[4600 : 4800]
22/24 [==========================>...] - ETA: 4s - loss: 0.1327 - acc: 0.8255
data[4800 : 5000]
23/24 [===========================>..] - ETA: 2s - loss: 0.1325 - acc: 0.8265
data[0 : 200]
In data_generator.
data[0 : 200]
data[200 : 400]
data[200 : 400]
data[400 : 600]
data[400 : 600]
data[600 : 800]
data[600 : 800]
data[800 : 1000]
data[800 : 1000]
data[1000 : 1200]
24/24 [==============================] - 75s 3s/step - loss: 0.1318 - acc: 0.8281 - val_loss: 0.1190 - val_acc: 0.8625
Epoch 2/10
data[1200 : 1400]
1/24 [>.............................] - ETA: 43s - loss: 0.1242 - acc: 0.8550
data[1400 : 1600]
2/24 [=>............................] - ETA: 46s - loss: 0.1209 - acc: 0.8600
data[1600 : 1800]
3/24 [==>...........................] - ETA: 46s - loss: 0.1208 - acc: 0.8600
data[1800 : 2000]
4/24 [====>.........................] - ETA: 45s - loss: 0.1199 - acc: 0.8613
data[2000 : 2200]
5/24 [=====>........................] - ETA: 43s - loss: 0.1194 - acc: 0.8620
data[2200 : 2400]
6/24 [======>.......................] - ETA: 41s - loss: 0.1196 - acc: 0.8617
data[2400 : 2600]
7/24 [=======>......................] - ETA: 40s - loss: 0.1202 - acc: 0.8607
data[2600 : 2800]
8/24 [=========>....................] - ETA: 37s - loss: 0.1202 - acc: 0.8606
data[2800 : 3000]
9/24 [==========>...................] - ETA: 35s - loss: 0.1203 - acc: 0.8606
data[3000 : 3200]
10/24 [===========>..................] - ETA: 33s - loss: 0.1206 - acc: 0.8600
data[3200 : 3400]
11/24 [============>.................] - ETA: 30s - loss: 0.1209 - acc: 0.8595
data[3400 : 3600]
12/24 [==============>...............] - ETA: 28s - loss: 0.1209 - acc: 0.8596
data[3600 : 3800]
13/24 [===============>..............] - ETA: 26s - loss: 0.1211 - acc: 0.8592
data[3800 : 4000]
14/24 [================>.............] - ETA: 23s - loss: 0.1211 - acc: 0.8593
data[4000 : 4200]
15/24 [=================>............] - ETA: 21s - loss: 0.1213 - acc: 0.8590
data[4200 : 4400]
16/24 [===================>..........] - ETA: 19s - loss: 0.1215 - acc: 0.8588
data[4400 : 4600]
17/24 [====================>.........] - ETA: 16s - loss: 0.1214 - acc: 0.8588
data[4600 : 4800]
18/24 [=====================>........] - ETA: 14s - loss: 0.1215 - acc: 0.8586
data[4800 : 5000]
19/24 [======================>.......] - ETA: 11s - loss: 0.1217 - acc: 0.8584
data[0 : 200]
20/24 [========================>.....] - ETA: 9s - loss: 0.1216 - acc: 0.8585
data[200 : 400]
21/24 [=========================>....] - ETA: 7s - loss: 0.1217 - acc: 0.8583
data[400 : 600]
22/24 [==========================>...] - ETA: 4s - loss: 0.1218 - acc: 0.8582
data[600 : 800]
23/24 [===========================>..] - ETA: 2s - loss: 0.1216 - acc: 0.8585
data[800 : 1000]
data[0 : 200]
So, I have two questions: