Keras有状态LSTM fit_generator如何使用batch_size> 1

时间:2018-01-22 13:35:39

标签: python keras lstm stateful

我想使用Keras中的功能API训练有状态 LSTM网络。

拟合方法为fit_generator

我可以使用:batch_size = 1

训练它

我的输入图层是:

Input(shape=(n_history, n_cols),batch_shape=(batch_size, n_history, n_cols), 
    dtype='float32', name='daily_input')

生成器如下:

def training_data():
    while 1:       
        for i in range(0,pdf_daily_data.shape[0]-n_history,1):            
            x = f(i)() # f(i) shape is (1, n_history, n_cols)
            y = y(i)
            yield (x,y)

然后合适的是:

model.fit_generator(training_data(),
                    steps_per_epoch=pdf_daily_data.shape[0]//batch_size,...

这很有效,但是很慢,并且在batch_size = 1

之后的每个时间步都执行渐变更新

在此配置中,我如何设置batch_size > 1请记住:LSTM图层 stateful = True

1 个答案:

答案 0 :(得分:0)

您必须将您的生成器修改为yeld您想要批次所需的元素数量

目前,您正在按元素迭代数据元素(根据您的第三个range()参数),获得单个 xy,以及然后产生那个元素。当您返回单个元素时,您将获得batch_size=1,因为fit_generator是逐个元素训练的。

假设您希望批量大小为10,则必须对数据进行切片并获得每个10个元素的片段,并yield这些片段而不是单个元素。请确保相应地对输入图层的形状反映这些更改,并传递相应的batch_size

相关问题