所以我试图使用Keras' fit_generator使用自定义数据生成器提供给LSTM网络。
为了说明问题,我创建了一个玩具示例,尝试以简单的升序序列预测下一个数字,并使用Keras TimeseriesGenerator创建一个Sequence实例:
WINDOW_LENGTH = 4
data = np.arange(0,100).reshape(-1,1)
data_gen = TimeseriesGenerator(data, data, length=WINDOW_LENGTH,
sampling_rate=1, batch_size=1)
我使用简单的LSTM网络:
data_dim = 1
input1 = Input(shape=(WINDOW_LENGTH, data_dim))
lstm1 = LSTM(100)(input1)
hidden = Dense(20, activation='relu')(lstm1)
output = Dense(data_dim, activation='linear')(hidden)
model = Model(inputs=input1, outputs=output)
model.compile(loss='mse', optimizer='rmsprop', metrics=['accuracy'])
并使用fit_generator
函数训练它:
model.fit_generator(generator=data_gen,
steps_per_epoch=32,
epochs=10)
这完美训练,模型按预期进行预测。
现在的问题是,在我的非玩具情况下,我想在将数据输入fit_generator
之前处理来自TimeseriesGenerator的数据。作为向前迈出的一步,我创建了一个生成器函数,它只包含之前使用过的TimeseriesGenerator。
def get_generator(data, targets, window_length = 5, batch_size = 32):
while True:
data_gen = TimeseriesGenerator(data, targets, length=window_length,
sampling_rate=1, batch_size=batch_size)
for i in range(len(data_gen)):
x, y = data_gen[i]
yield x, y
data_gen_custom = get_generator(data, data,
window_length=WINDOW_LENGTH, batch_size=1)
但现在奇怪的是,当我像以前一样训练模型,但是使用这个生成器作为输入时,
model.fit_generator(generator=data_gen_custom,
steps_per_epoch=32,
epochs=10)
没有错误,但训练错误遍布整个地方(上下跳跃而不是像其他方法那样持续下降),并且模型没有学会做出好的预测。 / p>
我的自定义生成器方法出了什么问题?
答案 0 :(得分:3)
可能是因为对象类型从Sequence
更改为TimeseriesGenerator
到通用生成器。 fit_generator
函数以不同方式处理这些问题。更简洁的解决方案是继承类并覆盖处理位:
class CustomGen(TimeseriesGenerator):
def __getitem__(self, idx):
x, y = super()[idx]
# do processing here
return x, y
并像以前一样使用这个类,因为其余的内部逻辑将保持不变。
答案 1 :(得分:0)
我个人对nuric代码有疑问。由于某种原因,我有一个超级错误,无法编写脚本。这是我可能的解决方法。让我知道这是否可行吗?
class CustomGen(TimeseriesGenerator):
def __getitem__(self, idx):
x,y = super().__getitem__(idx)
return x, y