train_on_batch()与fit()有何不同?什么情况下我们应该使用train_on_batch()?
答案 0 :(得分:21)
我认为您的意思是将train_on_batch
与fit
(以及fit_generator
之类的变体)进行比较,因为train
不是Keras常用的API函数。
对于这个问题,它是simple answer from the primary author:
使用fit_generator,您可以使用生成器作为验证数据 好。一般来说,我建议使用fit_generator,但使用 train_on_batch也可以。这些方法只是为了存在而存在 在不同用例中的便利性,没有"正确"方法
train_on_batch
允许您根据您提供的样本集合明确更新权重,而不考虑任何固定的批量大小。如果你想要的话,你可以使用它:训练一个明确的样本集合。您可以使用该方法在传统训练集的多个批处理中维护自己的迭代,但允许fit
或fit_generator
为您重复批处理可能更简单。
使用train_on_batch
可能更好的一种情况是更新一批新样本的预训练模型。假设您已经训练并部署了模型,并且有一段时间后您已经收到了以前从未使用过的一组新的训练样本。您可以使用train_on_batch
仅在这些示例上直接更新现有模型。其他方法也可以这样做,但在这种情况下使用train_on_batch
是明确的。
除了这样的特殊情况(要么你有一些教学理由在不同的培训批次中保持自己的光标,或者对于特殊批次的某种类型的半在线培训更新),它可能更好始终使用fit
(适用于内存中的数据)或fit_generator
(用于将批量数据作为生成器进行流式传输)。
答案 1 :(得分:9)
train_on_batch()
使您可以更好地控制LSTM的状态,例如,在使用有状态LSTM且需要控制对model.reset_states()
的调用时。您可能具有多系列数据,并且需要在每个系列之后重置状态,您可以使用train_on_batch()
进行此操作,但是如果使用.fit()
,则网络将接受所有系列数据的训练,而无需重置状态。没有对与错,这取决于您使用的是什么数据,以及您希望网络如何运行。
答案 2 :(得分:1)
如果您使用大型数据集并且没有容易序列化的数据(如高阶numpy数组)写入tfrecords,Train_on_batch的性能也会比fit和fit生成器提高。
在这种情况下,当整个数组都不适合内存时,可以将数组另存为numpy文件,并在内存中加载它们的较小子集(traina.npy,trainb.npy等)。然后,您可以使用tf.data.Dataset.from_tensor_slices,然后对子数据集使用train_on_batch,然后加载另一个数据集并再次批量调用train等。现在,您已经对整个集合进行了训练,并且可以精确控制数量和内容数据集训练模型。然后,您可以使用简单的循环和函数来定义自己的纪元,批处理大小等,以从数据集中获取数据。
答案 3 :(得分:1)
确实,@ nbro答案很有帮助,只是增加了一些场景,可以说您正在训练一些seq到seq模型或具有一个或多个编码器的大型网络。我们可以使用train_on_batch创建自定义训练循环,并使用我们的部分数据直接在编码器上进行验证,而无需使用回调。为复杂的验证过程编写回调可能很困难。在某些情况下,我们希望进行批量培训。
关于, 卡尔提克