train_on_batch()在keras模型中做了什么?

时间:2018-01-31 19:41:02

标签: python tensorflow machine-learning keras artificial-intelligence

所以我看到了一个代码示例(太大而无法粘贴),作者使用model.train_on_batch(in, out)代替model.fit(in, out)。 keras的官方文件说:

  

对一批样品进行单一梯度更新。

但我不明白。它是否与fit()相同,但它没有做很多前馈 - backprop,它只执行一次?还是我错了?请解释一下。

3 个答案:

答案 0 :(得分:7)

是的,train_on_batch只使用一个批次和一次。

虽然fit训练许多时期的许多批次。 (每个批次都会导致权重更新)。

使用train_on_batch的想法可能是在每批之间自己做更多事情。

答案 1 :(得分:0)

用于每次批次培训后我们希望了解并进行一些自定义更改时使用。

一个更精确的用例是GAN。 您必须更新鉴别器,但在更新GAN网络的过程中,必须使鉴别器不易训练。因此,您首先要训练鉴别器,然后再训练gan使鉴别器不可训练。 请参阅以下内容以了解更多信息: https://medium.com/datadriveninvestor/generative-adversarial-network-gan-using-keras-ce1c05cfdfd3

答案 2 :(得分:0)

模型的方法拟合将训练模型一次通过您提供的数据,但是由于内存(尤其是GPU内存)的限制,我们无法一次训练大量样本,因此我们需要将这些数据分成称为迷你批次的小片段(或成批)。 keras模型的方法拟合将为您完成此数据划分,并遍历您提供的所有数据。

但是,有时我们需要更复杂的训练程序,例如,我们希望在每个时期随机选择新样本放入批处理缓冲区中(例如GAN训练和Siamese CNN训练...),在这种情况下,我们不使用看中了一个简单的拟合方法,但是我们使用了train_on_batch方法。要使用此方法,我们在每次迭代中生成一批输入和一批输出(标签),并将其传递给此方法,它将立即对一批中的整个样本进行训练,并给出损失和其他指标计算批次样品。