Network overfits easily after changing the bottom layers

Time: 2019-07-16 21:18:25

Tags: deep-learning computer-vision classification

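I am trying to train a CNN based on Inception-v3 to classify 4 types of crystal images. Because the images contain many fine details that should help classification, I modified the network by adding some convolutional layers at the bottom (the input side), one of them with a stride of 2, so that it can accept larger images than the original Inception-v3 network. However, after adding these layers, the network overfits more easily and performs worse than the original network.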

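I added 16 convolutional layers to the bottom of the original Inception-v3 network so that it can accept 599 * 599 images, but the network now overfits easily. I tried increasing the weight decay, but the accuracy ended up even lower than that of the original network, and convergence was slow.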

    import tensorflow as tf

    # 16 Conv-BatchNorm-ReLU blocks placed in front of Inception-v3.
    # The first 15 use stride 1; only the last one uses stride 2.
    stem_layers = []
    for i in range(16):
        # Only the first layer declares the input shape.
        conv_kwargs = {"input_shape": (599, 599, 3)} if i == 0 else {}
        stem_layers += [
            tf.keras.layers.Conv2D(filters=3, kernel_size=3,
                                   strides=2 if i == 15 else 1,
                                   padding="same", **conv_kwargs),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Activation("relu"),
        ]

    base_model = tf.keras.models.Sequential(stem_layers)

    inception_model = tf.keras.applications.inception_v3.InceptionV3(include_top=True, weights=None, 
                                                                     input_tensor=base_model.output, 
                                                                     input_shape=None,
                                                                     pooling=None, classes=4)

    model = tf.keras.models.Model(inputs=base_model.input, outputs=inception_model.output)
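For reference, the shape arithmetic of the added layers can be checked without TensorFlow. This is only a sketch based on the layer settings described above (3x3 kernels, 3 filters, `padding="same"`, 15 layers with stride 1 followed by one with stride 2); the helper names `same_conv_output_size` and `conv_params` are made up for illustration.

```python
import math


def same_conv_output_size(size: int, stride: int) -> int:
    """Spatial output size of a conv layer with padding="same"."""
    return math.ceil(size / stride)


def conv_params(kernel_size: int, in_channels: int, filters: int) -> int:
    """Number of weights plus biases in one Conv2D layer."""
    return kernel_size * kernel_size * in_channels * filters + filters


# The 16 added layers: 15 with stride 1, then one with stride 2.
strides = [1] * 15 + [2]

size = 599
receptive_field, jump = 1, 1
for s in strides:
    size = same_conv_output_size(size, s)
    # Each 3x3 kernel widens the receptive field by 2 * current jump.
    receptive_field += (3 - 1) * jump
    jump *= s

stem_conv_params = len(strides) * conv_params(3, 3, 3)

print(size)              # 300
print(receptive_field)   # 33
print(stem_conv_params)  # 1344
```

Under these assumptions the added layers pass a 300 x 300 x 3 tensor to Inception-v3, which is close to its default 299 x 299 input, and they contribute 1,344 convolution parameters in total because every layer has only 3 filters.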

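The original Inception-v3 reaches about 88% accuracy without any weight decay, whereas with the larger input images and a weight decay of 0.3, the network only reaches about 70% accuracy.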
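The post does not show how the 0.3 decay is applied. Assuming it is an L2 penalty of the kind that `tf.keras.regularizers.l2` adds to the loss (the decay factor times the sum of squared weights), the sketch below shows how such a term scales. The weight values are hypothetical and are not taken from the network above.

```python
import math


def l2_penalty(weights, decay):
    """L2 penalty: decay * sum of squared weights."""
    return decay * sum(w * w for w in weights)


# Hypothetical kernel weights, for illustration only.
weights = [0.5, -0.5, 1.0, -1.0]

# Cross-entropy of a uniform guess over 4 classes, as a reference scale.
uniform_cross_entropy = math.log(4)

penalty_without_decay = l2_penalty(weights, 0.0)
penalty_with_decay = l2_penalty(weights, 0.3)

print(penalty_without_decay)   # no penalty when decay is 0
print(penalty_with_decay)      # roughly 0.3 * 2.5
print(uniform_cross_entropy)   # roughly 1.386
```

Whether this matches the actual setup depends on how the decay was configured, which the post does not say.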

0 answers:

No answers