feed字典命令在训练循环中的作用

时间:2018-05-22 14:38:37

标签: python tensorflow

我正在努力了解训练循环。在计算训练和测试准确度时,我们会通过训练和测试集替换xy_,但在打印交叉熵的结果时,为什么我们要xy_ {{1}分别和batch_xs?我知道我们必须为占位符指定一个值,但在计算测试准确度时batch_ysbatch_xs保持不变,并且两者都使用列车集值。

batch_ys

1 个答案:

答案 0 :(得分:0)

您可以在进行训练的同时计算您的训练损失(这样您就不必再单独打电话,稍后再次通过您的训练),例如:

for j in range(nSteps):
    # ...
    _, train_loss = sess.run([train_step, cross_entropy],
                             feed_dict={x: batch_xs[i], y_: batch_ys[i]})

至于为什么在计算准确度时不需要调整训练/测试集的大小,让我们在整个代码中查看数据的维度。你从:

开始
x = tf.placeholder(tf.float32, [None, nPixels])
y_ = tf.placeholder(tf.float32, [None, nLabels])

这意味着x的维度为[num_samples, nPixels]y_的维度为[num_samples, nLabels]。我认为x_trainy_train只是其中的一些子样本,因此分别为[num_train, nPixels][num_train, nLabels]。接下来,为您的批次重塑这些:

batch_xs = np.reshape(x_train,(nSteps,bSize,nPixels))
batch_ys = np.reshape(y_train,(nSteps,bSize,nLabels))

现在batch_xs的维度为[num_steps, batch_size, nPixels]batch_ys的维度为[num_steps, batch_size, nLabels]。请注意,最后尺寸,要素数量或输出尺寸未发生变化。

现在,对于训练循环的每次迭代,您可以使用batch_xs[i]batch_ys[i]从这些列表的第一维中获取一个元素。这些人的维度现在分别为[batch_size, nPixels][batch_size, nLables]。同样,最后一个维度没有改变。

最后,您可以批量或整个训练集调用accuracycross_entropy操作,TF不关心它是什么!这是因为您提供的功能/标签数量无论哪种方式都相同,所有改变的是您传递的数据元素的数量。在一种情况下,您传递batch_size个元素,而在另一个案例中,您传递的是num_trainnum_test个元素,但最后一个维度指定了每个样本的外观是一样的! TF很神奇,并想出如何处理差异。

旧答案

看起来你只是试图让你的循环中的训练损失,因此你将训练批次传递给cross_entropy函数。很多时候,您可能会在每个批次中对此进行评估并存储结果,以便您可以在以后随时间绘制训练损失,但有时这种计算需要足够长的时间,以至于您只需要经常计算它。您的代码似乎就是这种情况。

通常,在训练循环运行时,最好同时查看训练测试(或验证)丢失。这是因为您可以更好地看到过度拟合的影响,其中训练损失继续减少并且测试损失水平关闭或开始增加。类似的东西:

if j % 100 == 0:
    train_loss = sess.run(cross_entropy, feed_dict={x: batch_xs[i], y_: batch_ys[i]})
    test_loss = sess.run(cross_entropy, feed_dict={x: x_test, y_: y_test})
    # print or store your training and test loss

监控测试损失的主要原因之一是在过度装配时尽早停止培训。您可以通过存储最佳测试损失,然后将每个新的测试损失与您存储的最佳测试损失进行比较来实现。每当你得到一个没有达到最佳效果的新测试损失时,你增加一个计数器,当计数器达到一定数量(比如10或20)时,你就会杀死训练循环并称之为好。通过这种方式,当您的训练停止在验证/测试集上时,您就会发现这是一个更好的指示,它可以更好地推广到新数据。

如果您遇到的新验证损失优于最佳验证,则将其存储并重置计数器。

最后,通常最好保留两组数据,一组称为验证组,另一组称为测试组。当您将其拆分为这样时,通常在训练时使用验证集来查找过度拟合并执行提前停止,然后针对网络从未见过的测试集执行最后一次测试,以确保准确性/丢失了解其整体概括性能。这看起来像是:

for epoch:
    for batch in batches:
        x = batch.x
        y = batch.y
        sess.run(train_step, feed_dict={x: x, y_: y})
        if j % 100 == 0:
            train_loss = sess.run(cross_entropy, feed_dict={x: x, y_: y})
            test_loss = sess.run(cross_entropy, feed_dict={x: x_val, y_: y_val})  # note validation data
            # print or store your training and test loss
            # check for early stopping
test_loss = sess.run(cross_entropy, feed_dict={x: x_test, y_: y_test})
# Or test accuracy or whatever you want to evaluate.
# Point is the network never saw this data until now.