Keras的fit_generator应该在每个纪元后重置发电机吗?

时间:2018-02-11 06:51:50

标签: python-3.x machine-learning neural-network keras generator

我正在尝试使用fit_generator和自定义生成器来读取对于内存来说太大的数据。我想要训练125万行,所以我一次产生50,000行。 fit_generator有25 steps_per_epoch,我认为这会带来每个1.25MM的纪录。我添加了一个print语句,以便我可以看到该进程正在做多少偏移,并且当它进入epoch 2时,我发现它超过了最大值。该文件中总共有175万条记录,并且一次它传递了10个步骤,它在create_feature_matrix调用中得到一个索引错误(因为它没有引入任何行)。

def get_next_data_batch():
    import gc
    nrows = 50000
    skiprows = 0

    while True:
        d = pd.read_csv(file_loc,skiprows=range(1,skiprows),nrows=nrows,index_col=0)
        print(skiprows)
        x,y = create_feature_matrix(d)
        yield x,y
        skiprows = skiprows + nrows
        gc.collect()
get_data = get_next_data_batch()

... set up a Keras NN ...

model.fit_generator(get_next_data_batch(), epochs=100,steps_per_epoch=25,verbose=1,workers=4,callbacks=callbacks_list)

我使用fit_generator时出错了还是需要对我的自定义生成器进行一些更改才能使其正常工作?

1 个答案:

答案 0 :(得分:7)

否 - fit_generator不会重置生成器,它只是继续调用它。为了实现您想要的行为,您可以尝试以下方法:

def get_next_data_batch(nb_of_calls_before_reset=25):
    import gc
    nrows = 50000
    skiprows = 0
    nb_calls = 0

    while True:
        d = pd.read_csv(file_loc,skiprows=range(1,skiprows),nrows=nrows,index_col=0)
        print(skiprows)
        x,y = create_feature_matrix(d)
        yield x,y
        nb_calls += 1
        if nb_calls == nb_of_calls_before_reset:
            skiprows = 0
        else:
            skiprows = skiprows + nrows
        gc.collect()