Validation accuracy is falling further and further behind training accuracy

Date: 2021-04-29 05:25:55

Tags: deep-learning neural-network mobilenet overfitting-underfitting

The model is still training, and the validation accuracy keeps falling further behind the training accuracy. Does this indicate overfitting? How can I overcome it? I am using a MobileNet model. Could lowering the learning rate help?

Epoch 10/50
6539/6539 [==============================] - 3386s 518ms/step - loss: 0.8198 - accuracy: 0.7470 - top3_acc: 0.9199 - top5_acc: 0.9645 - val_loss: 1.0399 - val_accuracy: 0.6940 - val_top3_acc: 0.8842 - val_top5_acc: 0.9406
Epoch 11/50
6539/6539 [==============================] - 3377s 516ms/step - loss: 0.7939 - accuracy: 0.7558 - top3_acc: 0.9248 - top5_acc: 0.9669 - val_loss: 1.0379 - val_accuracy: 0.6953 - val_top3_acc: 0.8844 - val_top5_acc: 0.9411
Epoch 12/50
6539/6539 [==============================] - 3386s 518ms/step - loss: 0.7593 - accuracy: 0.7644 - top3_acc: 0.9304 - top5_acc: 0.9702 - val_loss: 1.0454 - val_accuracy: 0.6953 - val_top3_acc: 0.8831 - val_top5_acc: 0.9410
Epoch 13/50
6539/6539 [==============================] - 3394s 519ms/step - loss: 0.7365 - accuracy: 0.7735 - top3_acc: 0.9340 - top5_acc: 0.9713 - val_loss: 1.0476 - val_accuracy: 0.6938 - val_top3_acc: 0.8856 - val_top5_acc: 0.9411
Epoch 14/50
6539/6539 [==============================] - 3386s 518ms/step - loss: 0.7049 - accuracy: 0.7824 - top3_acc: 0.9387 - top5_acc: 0.9739 - val_loss: 1.0561 - val_accuracy: 0.6935 - val_top3_acc: 0.8841 - val_top5_acc: 0.9398
Epoch 15/50
6539/6539 [==============================] - 3390s 518ms/step - loss: 0.6801 - accuracy: 0.7901 - top3_acc: 0.9421 - top5_acc: 0.9755 - val_loss: 1.0673 - val_accuracy: 0.6923 - val_top3_acc: 0.8828 - val_top5_acc: 0.9391
Epoch 16/50
6539/6539 [==============================] - 3635s 556ms/step - loss: 0.6516 - accuracy: 0.7991 - top3_acc: 0.9462 - top5_acc: 0.9772 - val_loss: 1.0747 - val_accuracy: 0.6905 - val_top3_acc: 0.8825 - val_top5_acc: 0.9388
Epoch 17/50
6539/6539 [==============================] - 4070s 622ms/step - loss: 0.6200 - accuracy: 0.8082 - top3_acc: 0.9502 - top5_acc: 0.9805 - val_loss: 1.0859 - val_accuracy: 0.6883 - val_top3_acc: 0.8814 - val_top5_acc: 0.9373
Epoch 18/50
6539/6539 [==============================] - 4092s 626ms/step - loss: 0.5896 - accuracy: 0.8182 - top3_acc: 0.9550 - top5_acc: 0.9822 - val_loss: 1.1029 - val_accuracy: 0.6849 - val_top3_acc: 0.8788 - val_top5_acc: 0.9367
Epoch 19/50
6539/6539 [==============================] - 4087s 625ms/step - loss: 0.5595 - accuracy: 0.8291 - top3_acc: 0.9589 - top5_acc: 0.9834 - val_loss: 1.1147 - val_accuracy: 0.6872 - val_top3_acc: 0.8797 - val_top5_acc: 0.9367
Epoch 20/50
6539/6539 [==============================] - 4015s 614ms/step - loss: 0.5361 - accuracy: 0.8367 - top3_acc: 0.9617 - top5_acc: 0.9852 - val_loss: 1.1325 - val_accuracy: 0.6833 - val_top3_acc: 0.8773 - val_top5_acc: 0.9361
Epoch 21/50
6539/6539 [==============================] - 4093s 626ms/step - loss: 0.5023 - accuracy: 0.8472 - top3_acc: 0.9661 - top5_acc: 0.9870 - val_loss: 1.1484 - val_accuracy: 0.6844 - val_top3_acc: 0.8773 - val_top5_acc: 0.9363
Epoch 22/50
6539/6539 [==============================] - 4094s 626ms/step - loss: 0.4691 - accuracy: 0.8570 - top3_acc: 0.9703 - top5_acc: 0.9892 - val_loss: 1.1730 - val_accuracy: 0.6802 - val_top3_acc: 0.8765 - val_top5_acc: 0.9337
Epoch 23/50
6539/6539 [==============================] - 4091s 626ms/step - loss: 0.4387 - accuracy: 0.8676 - top3_acc: 0.9737 - top5_acc: 0.9904 - val_loss: 1.1986 - val_accuracy: 0.6774 - val_top3_acc: 0.8735 - val_top5_acc: 0.9320
Epoch 24/50
6539/6539 [==============================] - 4033s 617ms/step - loss: 0.4122 - accuracy: 0.8752 - top3_acc: 0.9764 - top5_acc: 0.9915 - val_loss: 1.2157 - val_accuracy: 0.6782 - val_top3_acc: 0.8755 - val_top5_acc: 0.9322
Epoch 25/50
6539/6539 [==============================] - 4105s 628ms/step - loss: 0.3838 - accuracy: 0.8861 - top3_acc: 0.9794 - top5_acc: 0.9927 - val_loss: 1.2419 - val_accuracy: 0.6746 - val_top3_acc: 0.8711 - val_top5_acc: 0.9309
Epoch 26/50
6539/6539 [==============================] - 4098s 627ms/step - loss: 0.3551 - accuracy: 0.8964 - top3_acc: 0.9824 - top5_acc: 0.9938 - val_loss: 1.2719 - val_accuracy: 0.6741 - val_top3_acc: 0.8722 - val_top5_acc: 0.9294
Epoch 27/50
6539/6539 [==============================] - 4101s 627ms/step - loss: 0.3266 - accuracy: 0.9051 - top3_acc: 0.9846 - top5_acc: 0.9950 - val_loss: 1.2877 - val_accuracy: 0.6723 - val_top3_acc: 0.8709 - val_top5_acc: 0.9288
Epoch 28/50
6539/6539 [==============================] - 4007s 613ms/step - loss: 0.3022 - accuracy: 0.9147 - top3_acc: 0.9866 - top5_acc: 0.9955 - val_loss: 1.3156 - val_accuracy: 0.6687 - val_top3_acc: 0.8667 - val_top5_acc: 0.9266
Epoch 29/50
6539/6539 [==============================] - 3410s 521ms/step - loss: 0.2797 - accuracy: 0.9208 - top3_acc: 0.9886 - top5_acc: 0.9962 - val_loss: 1.3409 - val_accuracy: 0.6712 - val_top3_acc: 0.8682 - val_top5_acc: 0.9270
Epoch 30/50
6539/6539 [==============================] - 3398s 520ms/step - loss: 0.2555 - accuracy: 0.9292 - top3_acc: 0.9907 - top5_acc: 0.9969 - val_loss: 1.3703 - val_accuracy: 0.6684 - val_top3_acc: 0.8661 - val_top5_acc: 0.9252
Epoch 31/50
6539/6539 [==============================] - 3401s 520ms/step - loss: 0.2365 - accuracy: 0.9358 - top3_acc: 0.9926 - top5_acc: 0.9975 - val_loss: 1.3945 - val_accuracy: 0.6660 - val_top3_acc: 0.8659 - val_top5_acc: 0.9270
Epoch 32/50
6539/6539 [==============================] - 3387s 518ms/step - loss: 0.2174 - accuracy: 0.9414 - top3_acc: 0.9934 - top5_acc: 0.9979 - val_loss: 1.4218 - val_accuracy: 0.6687 - val_top3_acc: 0.8650 - val_top5_acc: 0.9229
Epoch 33/50
6539/6539 [==============================] - 3397s 519ms/step - loss: 0.1986 - accuracy: 0.9478 - top3_acc: 0.9948 - top5_acc: 0.9983 - val_loss: 1.4513 - val_accuracy: 0.6641 - val_top3_acc: 0.8620 - val_top5_acc: 0.9217
Epoch 34/50
6539/6539 [==============================] - 3394s 519ms/step - loss: 0.1814 - accuracy: 0.9533 - top3_acc: 0.9956 - top5_acc: 0.9986 - val_loss: 1.4752 - val_accuracy: 0.6656 - val_top3_acc: 0.8612 - val_top5_acc: 0.9207

Here is my code. I used the DeepFashion dataset with 209,222 training images, and an SGD optimizer with learning_rate=0.001.

import tensorflow as tf
from tensorflow.keras.layers import Reshape, BatchNormalization, Dense
from tensorflow.keras.models import Model

# MobileNet pre-trained on ImageNet
mobile = tf.keras.applications.mobilenet.MobileNet(weights='imagenet')

# Take the input of the 6th layer from the end (the 7x7x1024 feature map)
x = mobile.layers[-6].input

# Insert a multi-head self-attention block
# (MultiHeadsAttModel is a custom helper defined elsewhere)
x = Reshape([7*7, 1024])(x)
att = MultiHeadsAttModel(l=7*7, d=1024, dv=64, dout=1024, nv=16)
x = att([x, x, x])
x = Reshape([7, 7, 1024])(x)
x = BatchNormalization()(x)

# Reuse MobileNet's original head layers
x = mobile.get_layer('global_average_pooling2d')(x)
x = mobile.get_layer('reshape_1')(x)
x = mobile.get_layer('dropout')(x)
x = mobile.get_layer('conv_preds')(x)
x = mobile.get_layer('reshape_2')(x)

# New 50-class classifier
output = Dense(units=50, activation='softmax')(x)

model = Model(inputs=mobile.input, outputs=output)
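Since the log above shows validation loss rising from roughly epoch 11 onward while training loss keeps falling, one practical mitigation is to stop training when validation loss stops improving. This is a minimal sketch (the patience values are illustrative, not tuned for this dataset) of Keras callbacks that could be passed to `model.fit`:

```python
import tensorflow as tf

# Stop training once val_loss has not improved for 3 epochs and
# roll back to the weights from the best epoch.
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=3,
    restore_best_weights=True,
)

# Optionally halve the learning rate when val_loss plateaus.
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=2,
)

# model.fit(train_data, validation_data=val_data, epochs=50,
#           callbacks=[early_stop, reduce_lr])
```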

1 Answer:

Answer 0 (score: 1)

Your validation loss increases while your training loss keeps getting smaller with each epoch. This is a classic case of overfitting.

I am not familiar with the MobileNet model, but it would help if you shared the architecture or a link to its details.

I can only suggest, somewhat blindly, adding dropout (https://www.tensorflow.org/api_docs/python/tf/keras/layers/Dropout) to regularize your model (I am guessing your model has none). Honestly, I don't see how changing the learning rate would help overcome overfitting, so I would not recommend it.
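As a concrete illustration of this suggestion, here is a minimal sketch of inserting a Dropout layer before the final softmax classifier. It assumes (hypothetically) a 1024-dimensional feature vector feeding the 50-way Dense layer; the 0.5 rate is a common default, not a tuned value:

```python
import tensorflow as tf
from tensorflow.keras.layers import Dropout, Dense

# Placeholder 1024-d feature input (stands in for the tensor
# that feeds the final classifier in the question's model).
inputs = tf.keras.Input(shape=(1024,))
x = Dropout(0.5)(inputs)  # randomly drop 50% of units at train time
outputs = Dense(units=50, activation='softmax')(x)
head_model = tf.keras.Model(inputs, outputs)
```

Dropout is only active during training; at inference time the layer passes activations through unchanged.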

Since you haven't shared any information about the dataset's size, I am not sure about its scale and diversity. In any case, if the dataset is relatively small, you can augment it, which gives you a better chance of reducing overfitting.
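The augmentation idea can be sketched with Keras preprocessing layers; the transformations and their magnitudes below are illustrative defaults, not values tuned for DeepFashion:

```python
import tensorflow as tf

# A small random-augmentation pipeline applied to image batches.
augment = tf.keras.Sequential([
    tf.keras.layers.RandomFlip('horizontal'),
    tf.keras.layers.RandomRotation(0.1),   # up to ±10% of a full turn
    tf.keras.layers.RandomZoom(0.1),
])

# Apply inside the input pipeline, only at training time:
# images = augment(images, training=True)
```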