Exception when extending the Keras VGG16 model: self.assert_input_compatibility(x) raises

Time: 2017-01-17 18:41:02

Tags: deep-learning keras

To use the VGG16 network for a regression task, I extend it as follows:

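from keras.applications.vgg16 import VGG16
from keras.layers import Flatten, Dense, Dropout
from keras.models import Model

# input_tensor: an Input(...) tensor defined elsewhere (not shown in the question)
model = VGG16(input_tensor=input_tensor, include_top=False)  # convolutional base only
x = model.output
x = Flatten()(x)
x = Dense(1024, activation='relu')(x)
x = Dropout(0.5)(x)
x = Dense(512, activation='relu')(x)
x = Dropout(0.5)(x)
x = Dense(256, activation='relu')(x)
x = Dropout(0.5)(x)
x = Dense(1)(x)  # single linear output for regression
model = Model(model.input, x)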

The following exception is raised immediately:

  File "C:\Users\Ralph\Documents\GitHub\CarND-Behavioral-Cloning-P3\model.py", line 116, in <module>
    tf.app.run()
  File "C:\Users\Ralph\Miniconda3\envs\carnd-term1\lib\site-packages\tensorflow\python\platform\app.py", line 43, in run
    sys.exit(main(sys.argv[:1] + flags_passthrough))
  File "C:\Users\Ralph\Documents\GitHub\CarND-Behavioral-Cloning-P3\model.py", line 101, in main
    model = create_model()
  File "C:\Users\Ralph\Documents\GitHub\CarND-Behavioral-Cloning-P3\model.py", line 60, in create_model
    x = Dense(1024)(x)
  File "C:\Users\Ralph\Miniconda3\envs\carnd-term1\lib\site-packages\keras\engine\topology.py", line 529, in __call__
    self.assert_input_compatibility(x)
  File "C:\Users\Ralph\Miniconda3\envs\carnd-term1\lib\site-packages\keras\engine\topology.py", line 457, in assert_input_compatibility
    if K.ndim(x) < ndim:
  File "C:\Users\Ralph\Miniconda3\envs\carnd-term1\lib\site-packages\keras\backend\tensorflow_backend.py", line 396, in ndim
    dims = x.get_shape()._dims
AttributeError: 'function' object has no attribute 'get_shape'
Press any key to continue . . .

Do I have to change the architecture, or is there a problem with how the input dimensions are inferred?
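One way to check the dimension inference by hand is to compute the shapes the regression head should receive and compare them with `model.output_shape`. The sketch below is plain Python, not Keras; it assumes the standard VGG16 layout (3x3 'same' convolutions that keep the spatial size, five 2x2 stride-2 max-pools, 512 filters in the last block), and the helper names are hypothetical.

```python
# Hand-computed shapes for the regression head, assuming the standard VGG16
# convolutional base (include_top=False): 'same' 3x3 convolutions keep H and W,
# five 2x2 stride-2 max-pools halve them (floor), last block has 512 filters.

def vgg16_base_output(height, width):
    """Spatial output (h, w, channels) of the VGG16 convolutional base."""
    for _ in range(5):  # five pooling stages
        height //= 2
        width //= 2
    return height, width, 512

def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out

def head_summary(height, width, units=(1024, 512, 256, 1)):
    """Flatten size and Dense parameter counts (Dropout adds no parameters)."""
    h, w, c = vgg16_base_output(height, width)
    flat = h * w * c  # size of the Flatten() output
    params = []
    n_in = flat
    for n_out in units:
        params.append(dense_params(n_in, n_out))
        n_in = n_out
    return flat, params

if __name__ == "__main__":
    flat, params = head_summary(224, 224)
    print(flat)    # 25088 (7 * 7 * 512)
    print(params)  # [25691136, 524800, 131328, 257]
```

For a hypothetical 160x320 frame the same arithmetic gives 5 x 10 x 512 = 25600 inputs to the first Dense layer. If the Keras model reports matching shapes, the architecture itself is consistent.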

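2 answers:

Answer 0 (score: 1)

I dug into this; I believe it is a bug in the latest version that shows up when you call set_learning_phase. I filed an issue, https://github.com/fchollet/keras/issues/5268, and a pull request, https://github.com/fchollet/keras/pull/5269.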
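The traceback shows a Python function object arriving where Dense expects a tensor (`'function' object has no attribute 'get_shape'`). A plausible simplified model of that failure, assuming (a reconstruction, not taken from the Keras source) that a layer such as Dropout hands its two branches to a train/test switch as callables, and that once set_learning_phase has fixed the phase to a plain integer the switch returns the chosen branch without calling it. All names below are hypothetical toys, not the Keras API.

```python
# Toy reconstruction (hypothetical names, not the Keras API) of how a train/test
# switch can leak a function object instead of a tensor-like value.

class FakeTensor:
    """Stands in for a tensor; only get_shape() is needed here."""
    def __init__(self, shape):
        self.shape = shape

    def get_shape(self):
        return self.shape

def buggy_in_train_phase(train_branch, test_branch, phase):
    # With the phase fixed to an int, the chosen branch is returned as-is,
    # so a callable branch comes back uncalled.
    return train_branch if phase == 1 else test_branch

def fixed_in_train_phase(train_branch, test_branch, phase):
    chosen = train_branch if phase == 1 else test_branch
    # Evaluate callables so the caller always receives a value.
    return chosen() if callable(chosen) else chosen

def toy_dropout(x, in_train_phase, phase):
    def dropped():
        return FakeTensor(x.get_shape())  # pretend dropout keeps the shape

    def kept():
        return x

    return in_train_phase(dropped, kept, phase)

def toy_dense_input_check(x):
    # Mirrors the failing call in the traceback: x.get_shape()
    return x.get_shape()

x = FakeTensor((None, 25088))
print(toy_dense_input_check(toy_dropout(x, fixed_in_train_phase, 0)))  # (None, 25088)
try:
    toy_dense_input_check(toy_dropout(x, buggy_in_train_phase, 0))
except AttributeError as e:
    print(e)  # 'function' object has no attribute 'get_shape'
```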

As a temporary fix, you can roll back to 1.2.0.
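Rolling back is typically done with `pip install keras==1.2.0`. To make the script fail fast if a different version gets picked up, a guard like the following could sit at the top of model.py. The helper name and the exact-match policy are assumptions for illustration; `keras.__version__` is the version string Keras exposes.

```python
# Hypothetical guard: refuse to run unless the pinned Keras version is installed.

PINNED_KERAS = "1.2.0"  # version suggested above as a temporary fix

def check_keras_version(installed, pinned=PINNED_KERAS):
    """Raise if the installed version string differs from the pinned one."""
    if installed != pinned:
        raise RuntimeError(
            "Keras %s is installed, but this script is pinned to %s "
            "(pip install keras==%s)" % (installed, pinned, pinned))
    return True

# Usage at the top of model.py (requires Keras to be installed):
#     import keras
#     check_keras_version(keras.__version__)
```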

Answer 1 (score: 0)

Works for me:

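import numpy as np
from keras.applications.vgg16 import VGG16
from keras.layers import Input, Flatten, Dense, Dropout
from keras.models import Model

# Channels-first input: 3 color channels, 224x224 pixels
input_tensor = Input((3, 224, 224))
model = VGG16(input_tensor=input_tensor, include_top=False)
x = model.output
x = Flatten()(x)
x = Dense(1024, activation='relu')(x)
x = Dropout(0.5)(x)
x = Dense(512, activation='relu')(x)
x = Dropout(0.5)(x)
x = Dense(256, activation='relu')(x)
x = Dropout(0.5)(x)
x = Dense(1)(x)
model2 = Model(model.input, x)
# Predict on a single all-zeros image to check that the graph builds and runs
model2.predict(np.zeros((1, 3, 224, 224)))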

Output: array([[0.11504173]], dtype=float32)
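`Input((3, 224, 224))` is channels-first, which corresponds to Keras 1's 'th' image dimension ordering (the value returned by `keras.backend.image_dim_ordering()`); under 'tf' ordering the same image would be `(224, 224, 3)`, and the zeros batch passed to `predict` would need the matching shape. A small hypothetical helper that picks the input shape from the ordering string:

```python
# Hypothetical helper: choose the image Input shape from a Keras 1 dim-ordering string.

def image_input_shape(ordering, height, width, channels=3):
    """Return the Input shape for an image under 'th' or 'tf' ordering."""
    if ordering == "th":  # channels first, as in Input((3, 224, 224))
        return (channels, height, width)
    if ordering == "tf":  # channels last
        return (height, width, channels)
    raise ValueError("unknown dim ordering: %r" % (ordering,))

# In a Keras 1.x script this might be used as:
#     from keras import backend as K
#     shape = image_input_shape(K.image_dim_ordering(), 224, 224)
#     input_tensor = Input(shape)
#     model2.predict(np.zeros((1,) + shape))
```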