To use the VGG16 network for a regression task, I extend it as follows:
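import keras
from keras.layers import Input, Flatten, Dense, Dropout
from keras.models import Model

# Placeholder input; the original snippet does not show how input_tensor is defined (shape assumed)
input_tensor = Input(shape=(224, 224, 3))
# VGG16 convolutional base without the fully connected classifier on top
model = keras.applications.vgg16.VGG16(input_tensor=input_tensor, include_top=False)
x = model.output
x = Flatten()(x)
x = Dense(1024, activation='relu')(x)
x = Dropout(0.5)(x)
x = Dense(512, activation='relu')(x)
x = Dropout(0.5)(x)
x = Dense(256, activation='relu')(x)
x = Dropout(0.5)(x)
# Single linear output unit for regression
x = Dense(1)(x)
model = Model(model.input, x)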
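One way to judge whether the input dimensions are being inferred correctly is to work out by hand the shapes this head should see. The sketch below assumes the standard VGG16 convolutional base (five blocks, each ending in a 2x2 max-pool with stride 2, with 512 filters in the last block) and a 224x224 input, the size used in the working example in the second answer. The helper names are hypothetical and are not part of Keras.

```python
# Back-of-the-envelope shape check for a regression head on VGG16.
# Assumes the standard VGG16 conv base: 5 blocks, each ending in a 2x2
# max-pool with stride 2, last block with 512 filters, 'same'-padded convs.

def vgg16_base_output_shape(height, width):
    """Return (h, w, channels) of VGG16's output with include_top=False."""
    h, w = height, width
    for _ in range(5):  # one 2x2/stride-2 max-pool per block
        h, w = h // 2, w // 2
    return h, w, 512


def dense_params(n_in, n_out):
    """Weights (n_in * n_out) plus one bias per output unit."""
    return n_in * n_out + n_out


def head_summary(height, width, units=(1024, 512, 256, 1)):
    """Flattened size and (n_in, n_out, params) for each Dense layer.

    Dropout layers are omitted: they change neither shapes nor parameter counts.
    """
    h, w, c = vgg16_base_output_shape(height, width)
    flat = h * w * c  # size produced by Flatten()
    rows = []
    n_in = flat
    for n_out in units:
        rows.append((n_in, n_out, dense_params(n_in, n_out)))
        n_in = n_out
    return flat, rows


if __name__ == "__main__":
    flat, rows = head_summary(224, 224)
    print("Flatten output:", flat)  # 25088 = 7 * 7 * 512
    for n_in, n_out, params in rows:
        print(f"Dense {n_in} -> {n_out}: {params} params")
    print("Total head params:", sum(p for _, _, p in rows))  # 26347521
```

The flattened size is the same whether the image is laid out channels-first (3, 224, 224) or channels-last (224, 224, 3), so a correctly built head should show Flatten producing 25088 features and the first Dense layer holding 25,691,136 parameters.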
The following exception is thrown immediately:
  File "C:\Users\Ralph\Documents\GitHub\CarND-Behavioral-Cloning-P3\model.py", line 116, in <module>
    tf.app.run()
  File "C:\Users\Ralph\Miniconda3\envs\carnd-term1\lib\site-packages\tensorflow\python\platform\app.py", line 43, in run
    sys.exit(main(sys.argv[:1] + flags_passthrough))
  File "C:\Users\Ralph\Documents\GitHub\CarND-Behavioral-Cloning-P3\model.py", line 101, in main
    model = create_model()
  File "C:\Users\Ralph\Documents\GitHub\CarND-Behavioral-Cloning-P3\model.py", line 60, in create_model
    x = Dense(1024)(x)
  File "C:\Users\Ralph\Miniconda3\envs\carnd-term1\lib\site-packages\keras\engine\topology.py", line 529, in __call__
    self.assert_input_compatibility(x)
  File "C:\Users\Ralph\Miniconda3\envs\carnd-term1\lib\site-packages\keras\engine\topology.py", line 457, in assert_input_compatibility
    if K.ndim(x) < ndim:
  File "C:\Users\Ralph\Miniconda3\envs\carnd-term1\lib\site-packages\keras\backend\tensorflow_backend.py", line 396, in ndim
    dims = x.get_shape()._dims
AttributeError: 'function' object has no attribute 'get_shape'
Do I need to change the architecture, or is something going wrong with how the input dimensions are inferred?
Answer 0 (score: 1)
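I dug into this; I believe it is a bug in the latest Keras release that is triggered when you call set_learning_phase. I filed https://github.com/fchollet/keras/issues/5268 and https://github.com/fchollet/keras/pull/5269.
As a temporary fix, you can roll back to Keras 1.2.0.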
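The error message itself says that a layer received a Python function where it expected a tensor, so the failure comes from whatever produced `x`, not from the architecture. The sketch below is a hypothetical, framework-free illustration of how that class of bug arises: every name in it is made up and none of it is Keras code, and whether the Keras bug follows exactly this pattern is an assumption.

```python
# Hypothetical, framework-free illustration; none of these names are Keras APIs.

class FakeTensor:
    """Minimal stand-in for a backend tensor that can report its shape."""

    def __init__(self, shape):
        self._shape = shape

    def get_shape(self):
        return self._shape


def select_branch_buggy(train_fn, test_fn, training):
    # Bug pattern: hands back the branch *function* instead of its result.
    return train_fn if training else test_fn


def select_branch_fixed(train_fn, test_fn, training):
    # Correct pattern: call the chosen branch so a tensor comes back.
    return train_fn() if training else test_fn()


def ndim(x):
    # Same operation that fails in the traceback: x.get_shape()
    return len(x.get_shape())


def try_ndim(x):
    """Return ndim(x), or the AttributeError message if x is not tensor-like."""
    try:
        return ndim(x)
    except AttributeError as err:
        return str(err)


x = FakeTensor((None, 512))
good = select_branch_fixed(lambda: x, lambda: x, training=False)
bad = select_branch_buggy(lambda: x, lambda: x, training=False)
print(try_ndim(good))  # 2
print(try_ndim(bad))   # 'function' object has no attribute 'get_shape'
```

The buggy path reproduces the exact `AttributeError` text from the traceback even though the "model" is valid, which is consistent with the diagnosis that the problem lies in the library rather than in the layer stack.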
Answer 1 (score: 0)
This works for me:
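import numpy as np
import keras
from keras.layers import Input, Flatten, Dense, Dropout
from keras.models import Model

# Channels-first input (3, 224, 224); presumes Keras is set to Theano-style ('th') dim ordering
input_tensor = Input((3, 224, 224))
model = keras.applications.vgg16.VGG16(input_tensor=input_tensor, include_top=False)
x = model.output
x = Flatten()(x)
x = Dense(1024, activation='relu')(x)
x = Dropout(0.5)(x)
x = Dense(512, activation='relu')(x)
x = Dropout(0.5)(x)
x = Dense(256, activation='relu')(x)
x = Dropout(0.5)(x)
x = Dense(1)(x)
model2 = Model(model.input, x)
# Forward pass on a single all-zeros dummy image
model2.predict(np.zeros((1, 3, 224, 224)))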
Output: array([[0.11504173]], dtype=float32)