Keras回调定制时代循环

时间:2017-05-04 09:48:08

标签: python tensorflow keras

我用keras训练LSTM。输入序列具有不同的长度。假设序列的长度介于1和num_seq之间。因此,我按照每个时期的长度对序列进行分组,以便使用批量大小> 1:

for epoch in xrange(nb_epochs):
 for i in range(1,num_seq):
  X,y = get_sequences(length=i)
  model.fit(X,y,batch_size=100,epochs=1, validation_split=0.1, callbacks=None)

因为我在纪元上使用自定义循环,所以使用纪元信息的回调不能正常工作(例如张量板,历史记录等)。有什么方法可以解决这个问题?有没有办法告诉fit函数,它目前在哪个时代?

1 个答案:

答案 0 :(得分:2)

在训练期间对训练数据进行操作时,应该逐步使用model.train_on_batch或者 - 更好 - 使用fit_generator,这样可以定义一个python生成器,为每个批处理生成(x,y)个元组。这样就可以正确调用回调函数。

例如:

def train_gen():
   while True:
       for i in range(1,num_seq):
           X,y = get_sequences(length=i)
           yield X, y
model.fit_generator(train_gen, steps_per_epoch=num_seq)

这样做的缺点是您必须自己进行批处理,并且还必须自己提供验证分区,您也可以使用生成器(因此可以重用大部分代码)。