如何在每个时期而不是每个批次中产生损失?

时间:2019-01-05 16:23:35

标签: python machine-learning keras generative-adversarial-network

据我所知,纪元是在整个数据集上任意重复地反复运行的过程,而数据集又是按部分进行处理的,即所谓的批处理。在计算每个train_on_batch的损失之后,将更新权重,下一批将获得更好的结果。这些损失是我学习神经网络的质量和学习状态的指标。

在几个来源中,每个时期都会计算(并打印)损失。因此,我不确定我是否做对了。

此刻,我的GAN看起来像这样:

for epoch:
  for batch:

    fakes = generator.predict_on_batch(batch)

    dlc = discriminator.train_on_batch(batch, ..)
    dlf = discriminator.train_on_batch(fakes, ..)
    dis_loss_total = 0.5 *  np.add(dlc, dlf)

    g_loss = gan.train_on_batch(batch,..)

    # save losses to array to work with later

这些损失是针对每个批次的。如何获得一个时代?顺便说一句:我需要一个时代的损失,为什么呢?

2 个答案:

答案 0 :(得分:0)

没有直接的方法来计算一个时期的损失。实际上,一个时期的损失通常定义为该时期中批次损失的平均值。因此,您可以累积一个时期的损失值,最后将其除以该时期的批次数量:

epoch_loss = []
for epoch in range(n_epochs):
    acc_loss = 0.
    for batch in range(n_batches):
        # do the training 
        loss = model.train_on_batch(...)
        acc_loss += loss
    epoch_loss.append(acc_loss / n_batches)

对于另一个问题,历元损失的一种用法可能是将其用作停止训练的指标(但是,验证损失通常用于此目的,而不是训练损失)。

答案 1 :(得分:0)

我将在@today回答中进行扩展。如何报告某个时期的损失以及如何使用它来确定何时应该停止训练,存在一定的平衡。

  • 如果仅查看最近一批的损失,那将是对数据集损失的非常嘈杂的估计,因为也许该批次恰好存储了模型遇到问题的所有样本,或者所有琐碎的样本成功。
  • 如果查看该时期所有批次的平均损失,您可能会得到偏斜的响应,因为正如您所指出的那样,该模型已经(希望)在整个时期得到了改善,因此初始批次的性能不是与以后批次的性能相比有意义。

准确报告时期损失的唯一方法是使模型退出训练模式,即修复所有模型参数,然后在整个数据集上运行模型。这将是您的历元损失的无偏计算。但是,总的来说,这是一个糟糕的主意,因为如果您具有复杂的模型或大量的训练数据,则将浪费大量时间。

因此,我认为最常见的是通过报告 N 个迷你批次的平均损失来平衡这些因素,其中 N 足以消除噪声。单独的批次,但没有那么大,以至于第一批和最后一批之间的模型性能无法媲美。

我知道您在Keras,但是here是一个PyTorch示例,可以清楚地说明此概念,并在此处复制:

for epoch in range(2):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')

您可以看到他们累积了 N = 2000批次的损失,报告了这2000批次的平均损失,然后将运行损失归零并继续前进。