I am using the standard ResNet50 model:
model = keras.applications.resnet50.ResNet50(include_top=False,
                                             weights='imagenet',
                                             classes=10,
                                             input_shape=(224, 224, 3))
and I added a few dense layers of my own:
top_model = Sequential()
top_model.add(Flatten(input_shape=model.output_shape[1:]))
top_model.add(Dense(256, activation='relu'))
top_model.add(Dropout(0.5))
top_model.add(Dense(2, activation='softmax'))
model = Model(inputs=model.input, outputs=top_model(model.output))
This works fine. However, when I try to remove the last Dense and Dropout layers with model.pop(), Keras does not handle it well:
model.layers[-1].layers
[<keras.layers.core.Flatten at 0x16b5c00b8>,
<keras.layers.core.Dense at 0x16b5c0320>,
<keras.layers.core.Dropout at 0x16b5c02e8>,
<keras.layers.core.Dense at 0x16b5c0d68>]
model.layers[-1].pop()
model.layers[-1].pop()
model.layers[-1].layers
[<keras.layers.core.Flatten at 0x1ae6e5940>,
<keras.layers.core.Dense at 0x1ae6e9e10>]
model.layers[-1].outputs = [model.layers[-1].layers[-1].output]
model.outputs = model.layers[-1].outputs
model.layers[-1].layers[-1].outbound_nodes = []
Then I simply compile the model, and when I try to predict I get an error:
You must feed a value for placeholder tensor 'flatten_7_input_12' with dtype float and shape [?,1,1,2048]
Answer 0 (score: 0)
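model.pop() takes care of all the underlying bookkeeping, including setting model.output to the output of the new last layer. You therefore do not need to handle anything related to the outputs yourself.
Also note that you are reassigning the model variable, so model.outputs already refers to the outputs of the extended model.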
The following sample code works on Keras 2.0.6 with the TensorFlow backend (1.4.0):
import keras
from keras.models import Sequential, Model
from keras.layers import *
import numpy as np

# Convolutional base without the ImageNet classifier on top.
model = keras.applications.resnet50.ResNet50(include_top=False,
                                             weights='imagenet',
                                             classes=10,
                                             input_shape=(224, 224, 3))

# Custom head stacked on the base model's output.
top_model = Sequential()
top_model.add(keras.layers.Flatten(input_shape=model.output_shape[1:]))
top_model.add(keras.layers.Dense(256, activation='relu'))
top_model.add(keras.layers.Dropout(0.5))
top_model.add(keras.layers.Dense(2, activation='softmax'))

# Keep the extended model under its own name instead of overwriting `model`.
model_extended = Model(inputs=model.input, outputs=top_model(model.output))

# Remove the final Dense and the Dropout layer from the nested head.
model_extended.layers[-1].pop()
model_extended.layers[-1].pop()

model_extended.compile(optimizer='rmsprop',
                       loss='categorical_crossentropy',
                       metrics=['accuracy'])
model_extended.predict(np.zeros((1, 224, 224, 3)))
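
For reference, after the two pop() calls the head consists only of Flatten and Dense(256, activation='relu'), so predict should return one 256-dimensional feature vector per image rather than class probabilities. Here is a minimal NumPy sketch of that shape flow. It assumes the (?, 1, 1, 2048) feature-map shape quoted in the error message, and the random `kernel` and `bias` arrays are hypothetical stand-ins for the trained Dense weights, not values taken from Keras:

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for the base model's feature map; the shape (batch, 1, 1, 2048)
# is taken from the error message, the values are random.
features = rng.standard_normal((1, 1, 1, 2048))

# Hypothetical weights for the surviving Dense(256) layer (random, untrained).
kernel = rng.standard_normal((2048, 256)) * 0.01
bias = np.zeros(256)

# Flatten: collapse every axis except the batch axis.
flat = features.reshape(features.shape[0], -1)

# Dense(256, activation='relu'): affine map followed by ReLU.
head_output = np.maximum(flat @ kernel + bias, 0.0)

print(flat.shape)         # (1, 2048)
print(head_output.shape)  # (1, 256)
```

If the real predict call returns a different shape, the pops most likely did not propagate to the outer model's outputs.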