BatchNormalization blows up my Keras model

Posted: 2018-06-22 13:43:31

Tags: tensorflow keras deep-learning batch-normalization

I am trying to train a model in Keras with a TensorFlow backend, using the following code:

from keras import layers
from keras.models import Model

CHANNEL_AXIS = 3  # channels_last: the feature axis BatchNormalization normalizes over

img_width, img_height = 513, 128
nb_classes = 10
batch_size = 64

input_shape = (img_width, img_height, 3)
inputs = layers.Input(input_shape)

# First convolution; its output is also kept as the residual shortcut
tempModel = layers.Conv2D(filters=256, kernel_size=(4, 513), strides=(1, 1), padding='same')(inputs)

shortcut = tempModel

# Residual block: BN -> ReLU -> Conv, applied twice (pre-activation ordering)
tempModel = layers.BatchNormalization(axis=CHANNEL_AXIS)(tempModel)
tempModel = layers.Activation('relu')(tempModel)
tempModel = layers.Conv2D(filters=256, kernel_size=(4, 1), strides=(1, 1), padding='same', activation=None)(tempModel)

tempModel = layers.BatchNormalization(axis=CHANNEL_AXIS)(tempModel)
tempModel = layers.Activation('relu')(tempModel)
tempModel = layers.Conv2D(filters=256, kernel_size=(4, 1), strides=(1, 1), padding='same', activation=None)(tempModel)

# Merge the shortcut back in (shapes match: 'same' padding, equal filter counts)
tempModel = layers.add([shortcut, tempModel])

# Global max and average pooling over the spatial dimensions, then concatenated
max_p_layer = layers.GlobalMaxPooling2D(data_format='channels_last')(tempModel)
avg_p_layer = layers.GlobalAveragePooling2D(data_format='channels_last')(tempModel)

tempModel = layers.concatenate([max_p_layer, avg_p_layer])
tempModel = layers.Dense(300, activation='relu')(tempModel)
tempModel = layers.Dropout(0.2)(tempModel)
tempModel = layers.Dense(150, activation='relu')(tempModel)
tempModel = layers.Dropout(0.2)(tempModel)
tempModel = layers.Dense(nb_classes, activation='softmax')(tempModel)

model = Model(inputs=inputs, outputs=tempModel)

Now, when I try to train this model, training is very slow, especially compared to other architectures with far more parameters. The model also needs far more memory (over 30 GB in total, not counting the weights), which I believe is due to the BatchNormalization layers: at least when I remove them for testing, the model uses several GB less. Have I implemented the network incorrectly, or are BatchNormalization layers really this slow?
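For context, here is a back-of-the-envelope estimate of the activation memory (a rough sketch, assuming float32 activations and ignoring framework overhead; the numbers are illustrative, not measured):

# Size of one stored activation tensor after the first Conv2D.
# With padding='same' its output shape is (513, 128, 256) per example.
h, w, c = 513, 128, 256
batch = 64
floats_per_sample = h * w * c                     # 16,809,984 floats
bytes_per_batch = floats_per_sample * 4 * batch   # float32 -> ~4.3 GB per tensor
print(round(bytes_per_batch / 2**30, 1))          # ~4.0 GiB

Training keeps several tensors of roughly this size alive for backpropagation (the shortcut, each convolution output, and each BatchNormalization/ReLU output), so tens of GB seems plausible for this architecture even before BN's own cached values, which would be consistent with removing the BN layers freeing several GB.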

0 Answers:

No answers yet