我使用经过预先培训的VGG16网络跟踪Building powerful image classification models using very little data的教程。
我能够获得相当准确的数据集:)但是,我如何重用经过重新训练的网络?
model = Sequential()
model.add(Flatten(input_shape=train_data.shape[1:]))
model.add(Dense(256, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(nb_class, activation='softmax'))
model.compile(optimizer='rmsprop',
loss='categorical_crossentropy', metrics=['accuracy'])
model.fit(train_data, train_labels,
epochs=epochs,
batch_size=batch_size,
validation_data=(validation_data, validation_labels))
# After training, I save my model
model_json = model.to_json()
with open(top_model_json_path, "w") as json_file:
json_file.write(model_json)
model.save_weights(top_model_weights_path)
我希望将dog.1234.jpg
重新用于def main(_):
# vgg16
vgg_model = applications.VGG16(include_top=False, weights='imagenet')
# load top model and weights
json_file = open('retrain_VGG16_dogcat2000.json', 'r')
loaded_model_json = json_file.read()
json_file.close()
top_model = model_from_json(loaded_model_json)
top_model.load_weights("retrain_VGG16_dogcat2000.h5")
# concatenate vgg16 and top classifier
model = Model(inputs=vgg_model.input, outputs=top_model(vgg_model.output)) # error occurs
# test
image = cv2.imread(FLAGS.img_dir)
image = cv2.resize(image, (224, 224)).astype(np.float32) / 255.
image = image.reshape(1 , image.shape[0], image.shape[1], image.shape[2])
model.predict(image)
The shape of the input to "Flatten" is not fully defined (got (None, None, 512). Make sure to pass a complete "input_shape" or "batch_input_shape" argument to the first layer in your model.
发生错误:{{1}}
有什么想法吗?
答案 0 :(得分:0)
您的VGG模型没有明确定义的输入形状。它接受任何大小的图像。由于它是纯粹的卷积模型,因此输出形状也是可变的。这就是形状显示为Kernel.php
的原因。
但是(None, None, 512)
图层(存在于您自己的Flatten
中)不支持变量输入形状。它必须使用定义明确的一个。
您有两种可能的解决方案:
top_model
定义VGG模型的输入形状(这不需要您重新训练此模型,权重不依赖于纯卷积模型中图像的大小) input_shape=(224,224,3)
。 (这可能会降低你的模型的能力,而且肯定需要重新训练一切)答案 1 :(得分:0)
我目前能够以这种方式重复使用我的再训练模型。但是,如何将重新训练的模型连接到一个机器学习模型中呢?
classes = ['cat', 'dog']
# load VGG model
vgg_model = applications.VGG16(include_top=False, weights='imagenet')
# load model and weights
json_file = open('retrain_VGG16_dogcat2000.json', 'r')
loaded_model_json = json_file.read()
json_file.close()
top_model = model_from_json(loaded_model_json)
top_model.load_weights("retrain_VGG16_dogcat2000.h5")
# test image
image = cv2.imread(FLAGS.img_dir)
image = cv2.resize(image, (224, 224)).astype(np.float32) / 255.
image = image.reshape(1 , image.shape[0], image.shape[1], image.shape[2])
# predict the image throughout vgg net
vgg_out = vgg_model.predict(image)
# predict the top_model with input from vgg's output
prediction = top_model.predict(vgg_out)
print(prediction, classes[np.argmax(prediction)])
python evaluate.py --img_dir ~/datasets/dog_cat/dog.4457.jpg
[[5.281641e-32 1.000000e + 00]]狗
python evaluate.py --img_dir ~/datasets/dog_cat/cat.2345.jpg
[[1.000000e + 00 9.587834e-36]] cat