Tensorflow: Accuracy of custom CNN model drops when training=False

Asked: 2020-04-24 06:56:14

Tags: python tensorflow keras tensorflow2.x

I have trained a custom CNN model like the one below:

import tensorflow as tf
from tensorflow.keras.layers import Conv2D, BatchNormalization, MaxPooling2D, Flatten, Dense

# channel axis for BatchNormalization
bn_axis = -1 if tf.keras.backend.image_data_format() == 'channels_last' else 1

class MyModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        # `layers` is a read-only property of tf.keras.Model, so the
        # layer list is stored under a different attribute name
        self.model_layers = [
            Conv2D(filters=4, kernel_size=(12, 12), input_shape=(28, 28, 1), padding='same', activation='relu'),
            BatchNormalization(axis=bn_axis),
            Conv2D(filters=16, kernel_size=(2, 2), padding='same', activation='sigmoid'),
            BatchNormalization(axis=bn_axis),
            Conv2D(filters=1, kernel_size=(3, 3), padding='same', activation='relu'),
            Conv2D(28, (9, 9), activation='sigmoid'),
            BatchNormalization(axis=bn_axis),
            Conv2D(28, (3, 3), activation='sigmoid'),
            Conv2D(14, (3, 3), activation='relu'),
            BatchNormalization(axis=bn_axis),
            MaxPooling2D((2, 2)),
            Conv2D(56, (3, 3), activation='relu'),
            BatchNormalization(axis=bn_axis),
            Conv2D(56, (3, 3), activation='relu'),
            MaxPooling2D((2, 2)), Flatten(),
            Dense(200, activation='relu'), Dense(100, activation='relu'),
            Dense(10, activation='softmax', name='class_output'),
        ]

    def call(self, inputs, training=False):
        x = inputs
        for lyr in self.model_layers:
            x = lyr(x, training=training)
        return x

Sometimes (though not always!) when the training argument is set to False, the test accuracy drops to nearly 0.

If I test the trained model with something like:

def test_model(model, test_dataset, training=False):
    accuracy = tf.keras.metrics.Accuracy(name='Test Accuracy')
    for (x, y) in test_dataset:
        test_preds = model(x, training=training) 
        accuracy.update_state(y, tf.argmax(test_preds, 1))
    print("{}: {}\n".format(accuracy.name, accuracy.result()*100))

then, for example, I get 0.018% with training=False and 98.9% with training=True. Any ideas what could cause this behavior? BatchNormalization by itself does not seem to be the culprit, since I first observed similar behavior with Dropout layers instead of the BatchNormalization layers.
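For context on the mode-dependent behavior described above, here is a minimal standalone sketch (independent of the model in the question) of how a freshly initialized BatchNormalization layer behaves under training=True versus training=False: in training mode it normalizes with the current batch's statistics, while in inference mode it uses its moving averages, which only become meaningful after enough training updates.

```python
import numpy as np
import tensorflow as tf

bn = tf.keras.layers.BatchNormalization()
# a batch whose mean (~5) and std (~2) differ from a standard normal
x = tf.constant(np.random.normal(5.0, 2.0, size=(32, 4)).astype('float32'))

# training=True: normalized with the batch's own mean/variance,
# so the output mean is close to 0
y_train = bn(x, training=True)

# training=False: normalized with the moving averages, which are still
# at their initial values (mean 0, variance 1) on a fresh layer,
# so the input passes through almost unchanged
y_infer = bn(x, training=False)
```

The gap between these two outputs shrinks as the moving statistics converge toward the data statistics during training, which is why a mismatch between the two modes often points at poorly converged (or never updated) moving averages.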

0 Answers:

There are no answers yet.