Ensure the data generator passed to Keras' fit_generator() starts over at the beginning on every epoch

Asked: 2019-02-19 21:26:15

Tags: python tensorflow keras out-of-memory data-generation

I created a data generator to pass to Keras' fit_generator(), but the amount of data is not always a precise multiple of the batch size, so when an epoch ends I want to be sure the data generator resets and starts again from the beginning of the (HDF5) dataset. My ML classifier is an LSTM working on time-series data, so the order of the records matters.
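
To make the mismatch concrete, here is the arithmetic I have in mind (a minimal sketch with hypothetical numbers, not my actual dataset): fit_generator()'s steps_per_epoch has to agree with the number of full batches the generator can really produce, or the epoch boundaries drift.

# Hypothetical numbers, just to illustrate the bookkeeping:
nrecs = 5000                            # records in the HDF5 dataset
batch_size = 200
steps_per_epoch = nrecs // batch_size   # 25 full batches per pass
leftover = nrecs % batch_size           # 0 here; a nonzero remainder
                                        # has to be dropped or padded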

I know that when you build the data generator as a python class you can use the on_epoch_end() method, as described in Afshine Amidi's article A detailed example of how to use data generators in Keras, but I could never get the class version of the generator to work (a sketch of that pattern follows below). In any case, I have a huge dataset, I am already getting memory errors and dying jupyter kernels even in PyCharm, and I would rather use the simpler form of the data generator.
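
For reference, here is a minimal sketch of that class-based pattern. The class and attribute names are my own hypothetical choices; keras.utils.Sequence and its three methods are what the article relies on.

import numpy as np
import keras

class HDF5Sequence(keras.utils.Sequence):
    # Hypothetical sketch of the class-based generator pattern.
    def __init__(self, data, labels, batch_size):
        self.data = data          # e.g. an open h5py dataset
        self.labels = labels
        self.batch_size = batch_size

    def __len__(self):
        # Keras asks the Sequence how many batches one epoch has.
        return len(self.data) // self.batch_size

    def __getitem__(self, idx):
        # Return batch number idx.
        lo = idx * self.batch_size
        hi = lo + self.batch_size
        X = np.array(self.data[lo:hi])
        y = np.resize(np.array(self.labels[lo:hi]), (self.batch_size, 1))
        return X, y

    def on_epoch_end(self):
        # Keras calls this at the end of every epoch; the natural
        # place to reset or reshuffle indices.
        pass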

Here is the data generator I am passing to Keras' fit_generator(). Note that on every iteration of the generator's loop, one line of code (marked KEY LINE of CODE below) prints the indices used to extract data from the HDF5 dataset object. If you follow that line's output in the spew further down, you will notice that the indices fall out of sync with the epochs:

import h5py
import numpy as np

def data_generator(dataID,
                   batch_size,
                   dim = (20, 100)):

    print("\nIn data_generator.\n")      

    DataDir = "data/"
    BatchDir = DataDir + dataID + "/"
    h5path = BatchDir + dataID + ".h5"

    f = h5py.File(h5path, "r")
    data = f["sigccm"]
    labels = f["Attack"]

    # number of records to process
    nrecs = len(data)

    outputshape = (batch_size, *dim)

    while True:            

        for i in range(0, nrecs, batch_size):

            # Generate one batch of data.

            if i + batch_size > nrecs:

                # Not enough records remain for a
                # complete batch, so end this pass
                # over the data here; the while
                # loop starts the next epoch from
                # the beginning of the dataset.

                break

            else: 
                upperbound = i + batch_size
                print("\ndata[%d : %d]\n" % (i, upperbound)) # <<<<< KEY LINE of CODE
                X = np.array(data[i : upperbound]) 
                y = np.array(labels[i : upperbound])             

            if outputshape != X.shape:
                # Report and skip a malformed batch.
                msg = "Wrong shape: "
                idx = "index = {:d}, ".format(i)
                shp = "X.shape = {:s}".format(str(X.shape))
                print(msg + idx + shp)

            else:

                # Label (Attack) field has an extra
                # nested array dimension. Get rid of it.

                u = np.array(y)
                y = np.resize(u, (batch_size, 1))
                yield X, y  

    # Note: this line is never reached, because the
    # while True loop above never exits, so the file
    # is never explicitly closed.
    f.close()
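
For context, the training call looks roughly like this. It is a sketch, not my exact script: the model object and the validation arguments are assumptions, though steps_per_epoch=24 and epochs=10 match the progress bars in the spew below.

model.fit_generator(
    data_generator(dataID, batch_size=200),
    steps_per_epoch=24,     # matches the 24/24 progress bars below
    epochs=10,              # matches "Epoch 1/10" below
    validation_data=data_generator(val_dataID, batch_size=200),  # assumed
    validation_steps=4)     # assumed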

Here is what I get when I run Keras. The spew ends where it does because the kernel died, in both jupyter and PyCharm, when I hit a MemoryError inside the data generator:

In data_generator.
data[0 : 200]

Epoch 1/10
data[200 : 400]
data[400 : 600]

 1/24 [>.............................] - ETA: 1:33 - loss: 0.2952 - acc: 0.1350
data[600 : 800]

 2/24 [=>............................] - ETA: 1:10 - loss: 0.2234 - acc: 0.4975
data[800 : 1000]

 3/24 [==>...........................] - ETA: 1:01 - loss: 0.1923 - acc: 0.6200
data[1000 : 1200]

 4/24 [====>.........................] - ETA: 55s - loss: 0.1753 - acc: 0.6813 
data[1200 : 1400]

 5/24 [=====>........................] - ETA: 49s - loss: 0.1652 - acc: 0.7170
data[1400 : 1600]

 6/24 [======>.......................] - ETA: 45s - loss: 0.1587 - acc: 0.7400
data[1600 : 1800]

 7/24 [=======>......................] - ETA: 42s - loss: 0.1536 - acc: 0.7571
data[1800 : 2000]

 8/24 [=========>....................] - ETA: 39s - loss: 0.1496 - acc: 0.7700
data[2000 : 2200]

 9/24 [==========>...................] - ETA: 36s - loss: 0.1469 - acc: 0.7794
data[2200 : 2400]

10/24 [===========>..................] - ETA: 34s - loss: 0.1447 - acc: 0.7870
data[2400 : 2600]

11/24 [============>.................] - ETA: 31s - loss: 0.1426 - acc: 0.7936
data[2600 : 2800]

12/24 [==============>...............] - ETA: 29s - loss: 0.1411 - acc: 0.7988
data[2800 : 3000]

13/24 [===============>..............] - ETA: 26s - loss: 0.1395 - acc: 0.8035
data[3000 : 3200]

14/24 [================>.............] - ETA: 24s - loss: 0.1385 - acc: 0.8071
data[3200 : 3400]

15/24 [=================>............] - ETA: 21s - loss: 0.1375 - acc: 0.8103
data[3400 : 3600]

16/24 [===================>..........] - ETA: 19s - loss: 0.1365 - acc: 0.8134
data[3600 : 3800]

17/24 [====================>.........] - ETA: 17s - loss: 0.1358 - acc: 0.8159
data[3800 : 4000]

18/24 [=====================>........] - ETA: 14s - loss: 0.1351 - acc: 0.8181
data[4000 : 4200]

19/24 [======================>.......] - ETA: 12s - loss: 0.1344 - acc: 0.8203
data[4200 : 4400]

20/24 [========================>.....] - ETA: 9s - loss: 0.1339 - acc: 0.8220 
data[4400 : 4600]

21/24 [=========================>....] - ETA: 7s - loss: 0.1334 - acc: 0.8236
data[4600 : 4800]

22/24 [==========================>...] - ETA: 4s - loss: 0.1327 - acc: 0.8255
data[4800 : 5000]

23/24 [===========================>..] - ETA: 2s - loss: 0.1325 - acc: 0.8265
data[0 : 200]


In data_generator.
data[0 : 200]
data[200 : 400]
data[200 : 400]
data[400 : 600]
data[400 : 600]
data[600 : 800]
data[600 : 800]
data[800 : 1000]
data[800 : 1000]
data[1000 : 1200]

24/24 [==============================] - 75s 3s/step - loss: 0.1318 - acc: 0.8281 - val_loss: 0.1190 - val_acc: 0.8625
Epoch 2/10
data[1200 : 1400]

 1/24 [>.............................] - ETA: 43s - loss: 0.1242 - acc: 0.8550
data[1400 : 1600]

 2/24 [=>............................] - ETA: 46s - loss: 0.1209 - acc: 0.8600
data[1600 : 1800]

 3/24 [==>...........................] - ETA: 46s - loss: 0.1208 - acc: 0.8600
data[1800 : 2000]

 4/24 [====>.........................] - ETA: 45s - loss: 0.1199 - acc: 0.8613
data[2000 : 2200]

 5/24 [=====>........................] - ETA: 43s - loss: 0.1194 - acc: 0.8620
data[2200 : 2400]

 6/24 [======>.......................] - ETA: 41s - loss: 0.1196 - acc: 0.8617
data[2400 : 2600]

 7/24 [=======>......................] - ETA: 40s - loss: 0.1202 - acc: 0.8607
data[2600 : 2800]

 8/24 [=========>....................] - ETA: 37s - loss: 0.1202 - acc: 0.8606
data[2800 : 3000]

 9/24 [==========>...................] - ETA: 35s - loss: 0.1203 - acc: 0.8606
data[3000 : 3200]

10/24 [===========>..................] - ETA: 33s - loss: 0.1206 - acc: 0.8600
data[3200 : 3400]

11/24 [============>.................] - ETA: 30s - loss: 0.1209 - acc: 0.8595
data[3400 : 3600]

12/24 [==============>...............] - ETA: 28s - loss: 0.1209 - acc: 0.8596
data[3600 : 3800]

13/24 [===============>..............] - ETA: 26s - loss: 0.1211 - acc: 0.8592
data[3800 : 4000]

14/24 [================>.............] - ETA: 23s - loss: 0.1211 - acc: 0.8593
data[4000 : 4200]

15/24 [=================>............] - ETA: 21s - loss: 0.1213 - acc: 0.8590
data[4200 : 4400]

16/24 [===================>..........] - ETA: 19s - loss: 0.1215 - acc: 0.8588
data[4400 : 4600]

17/24 [====================>.........] - ETA: 16s - loss: 0.1214 - acc: 0.8588
data[4600 : 4800]

18/24 [=====================>........] - ETA: 14s - loss: 0.1215 - acc: 0.8586
data[4800 : 5000]

19/24 [======================>.......] - ETA: 11s - loss: 0.1217 - acc: 0.8584
data[0 : 200]

20/24 [========================>.....] - ETA: 9s - loss: 0.1216 - acc: 0.8585 
data[200 : 400]

21/24 [=========================>....] - ETA: 7s - loss: 0.1217 - acc: 0.8583
data[400 : 600]

22/24 [==========================>...] - ETA: 4s - loss: 0.1218 - acc: 0.8582
data[600 : 800]

23/24 [===========================>..] - ETA: 2s - loss: 0.1216 - acc: 0.8585
data[800 : 1000]

data[0 : 200]

So, I have two questions:

  1. How do I make sure the data generator resets on every epoch and starts over from the beginning of the dataset? (One idea is sketched after this list.)
  2. Any guesses about how to avoid the memory errors and the dying jupyter kernels are welcome, though I suspect I am hitting a hard limit somewhere and simply need to use a smaller dataset.
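
For what it is worth, here is a minimal sketch of the direction I mean in question 1. It is an unverified idea, not a known fix: let the generator itself define the epoch boundary by yielding exactly steps full batches per pass, and hand that same steps value to fit_generator() as steps_per_epoch so the two cannot drift apart.

import numpy as np

def epoch_aligned_generator(data, labels, batch_size):
    # Hypothetical variant of the generator above: one pass of
    # the for loop is exactly one epoch, so every epoch restarts
    # at record 0.
    steps = len(data) // batch_size      # full batches per epoch
    while True:
        for step in range(steps):
            lo = step * batch_size
            hi = lo + batch_size
            yield np.array(data[lo:hi]), np.array(labels[lo:hi])

# and then: model.fit_generator(gen, steps_per_epoch=steps, ...)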

0 Answers:

There are no answers.