How do I load ImageNet weights before training AlexNet in Keras?

Time: 2019-11-07 10:51:21

Tags: python keras deep-learning conv-neural-network sequential

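Hi, I wrote AlexNet in Keras using the Sequential API. I would like to know whether, and how, I can load ImageNet weights to train the model.

At the moment I use a RandomNormal kernel initializer for every layer, but I want to train starting from ImageNet weights. I have the weights as an H5 file. Could someone also provide sample code?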

2 answers:

Answer 0 (score: 1)

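from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Activation, Flatten, Dense, Dropout

model = Sequential()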

# 1st Convolutional Layer
model.add(Conv2D(filters=96, input_shape=(224,224,3), kernel_size=(11,11), strides=(4,4), padding='valid'))
model.add(Activation('relu'))
# Max Pooling
model.add(MaxPooling2D(pool_size=(2,2), strides=(2,2), padding='valid'))

# 2nd Convolutional Layer
model.add(Conv2D(filters=256, kernel_size=(11,11), strides=(1,1), padding='valid'))
model.add(Activation('relu'))
# Max Pooling
model.add(MaxPooling2D(pool_size=(2,2), strides=(2,2), padding='valid'))

# 3rd Convolutional Layer
model.add(Conv2D(filters=384, kernel_size=(3,3), strides=(1,1), padding='valid'))
model.add(Activation('relu'))

# 4th Convolutional Layer
model.add(Conv2D(filters=384, kernel_size=(3,3), strides=(1,1), padding='valid'))
model.add(Activation('relu'))

# 5th Convolutional Layer
model.add(Conv2D(filters=256, kernel_size=(3,3), strides=(1,1), padding='valid'))
model.add(Activation('relu'))
# Max Pooling
model.add(MaxPooling2D(pool_size=(2,2), strides=(2,2), padding='valid'))

# Passing it to a Fully Connected layer
model.add(Flatten())
# 1st Fully Connected Layer (its input size is inferred from Flatten)
model.add(Dense(4096))
model.add(Activation('relu'))
# Add Dropout to prevent overfitting
model.add(Dropout(0.4))

# 2nd Fully Connected Layer
model.add(Dense(4096))
model.add(Activation('relu'))
# Add Dropout
model.add(Dropout(0.4))

# 3rd Fully Connected Layer
model.add(Dense(1000))
model.add(Activation('relu'))
# Add Dropout
model.add(Dropout(0.4))

# Output Layer
model.add(Dense(17))
model.add(Activation('softmax'))

model.summary()

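# Compile the model
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

# Load the pretrained weights before training; the layers stored in the
# H5 file must match the layers of this model
model.load_weights('weight.h5')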
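Since load_weights expects the shapes stored in the file to line up with the model, it can help to know what shapes the layers above actually produce. The following is only a rough sketch in plain Python, not part of the answer: the helper names `valid_output_size` and `trace_shapes` and the `LAYERS` table are made up for illustration, and the table simply copies the kernel sizes, strides and filter counts from the code above, assuming padding='valid' everywhere.

```python
def valid_output_size(size, kernel, stride):
    """Output length along one spatial axis for padding='valid'."""
    return (size - kernel) // stride + 1


# (name, kernel size, stride, filters) copied from the model above;
# filters is None for pooling layers, which keep the channel count.
LAYERS = [
    ("conv1", 11, 4, 96),
    ("pool1", 2, 2, None),
    ("conv2", 11, 1, 256),
    ("pool2", 2, 2, None),
    ("conv3", 3, 1, 384),
    ("conv4", 3, 1, 384),
    ("conv5", 3, 1, 256),
    ("pool3", 2, 2, None),
]


def trace_shapes(input_size=224, channels=3):
    """Return (name, height, width, channels) after each layer."""
    shapes = []
    size = input_size
    for name, kernel, stride, filters in LAYERS:
        size = valid_output_size(size, kernel, stride)
        if filters is not None:
            channels = filters
        shapes.append((name, size, size, channels))
    return shapes


shapes = trace_shapes()
for name, height, width, depth in shapes:
    print(f"{name}: {height} x {width} x {depth}")

# Number of values that Flatten() hands to the first Dense layer
_, last_h, last_w, last_c = shapes[-1]
flatten_units = last_h * last_w * last_c
print("flatten units:", flatten_units)
```

The same numbers should show up in the output of model.summary(), so comparing the two is a quick way to spot a layer whose kernel shape differs from what the weight file holds.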

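Answer 1 (score: 0)

Since you wrote AlexNet in Keras and the weights are in an H5 file, you can restore the weights from the H5 file into the Keras model. The model's layers have to match the layers the weights were saved from.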

model.load_weights('my_model_weights.h5')
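To see the save/load round trip end to end, here is a minimal sketch under a few assumptions: it uses the Keras bundled with TensorFlow (`from tensorflow import keras`; with standalone Keras, `import keras` would be the equivalent), a tiny stand-in network built by a hypothetical `build_model` helper instead of AlexNet, and a made-up file name `demo.weights.h5`. The `.weights.h5` suffix is chosen because newer Keras versions expect it for weight files, while older versions accept any name ending in `.h5`.

```python
import os
import tempfile

import numpy as np
from tensorflow import keras


def build_model():
    # Tiny stand-in network (hypothetical, not AlexNet) so the example runs quickly
    return keras.Sequential([
        keras.Input(shape=(8,)),
        keras.layers.Dense(4, activation="relu"),
        keras.layers.Dense(2, activation="softmax"),
    ])


source = build_model()  # plays the role of the pretrained network
target = build_model()  # same architecture, different random initial weights

# Write the source weights to an H5 file in a temporary directory
weights_path = os.path.join(tempfile.mkdtemp(), "demo.weights.h5")
source.save_weights(weights_path)

# Restore them into the second model, replacing its random initialization
target.load_weights(weights_path)

weights_match = all(
    np.array_equal(a, b)
    for a, b in zip(source.get_weights(), target.get_weights())
)
print("weights match:", weights_match)

# Training with target.fit(...) would now start from the loaded weights
target.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
```

Calling load_weights before fit is what makes training start from the stored values rather than from the random initializers.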