使用训练有素的模型层在keras中创建另一个模型

时间:2019-05-06 05:10:18

标签: python tensorflow machine-learning keras

我在Keras中创建了一个模型,如下所示:

    m = Sequential()
    m.add(Dense(912, activation='relu', input_shape=(943, 1)))
    m.add(Dense(728, activation='relu'))
    m.add(Dense(528, activation='relu'))
    m.add(Flatten())
    m.add(Dense(500, activation='relu', name="bottleneck"))
    m.add(Dense(528, activation='relu'))
    m.add(Dense(728, activation='relu'))
    m.add(Dense(943, activation='linear'))

    m.compile(loss='mean_squared_error', optimizer='SGD')
    m.summary()

现在,我要使用bottleneck层并添加以下创建的网络:

    model = Sequential()
    model.add(Dense(930, activation='relu', input_shape=(943, 1)))
    model.add(Flatten())
    model.add(m.get_layer('bottleneck'))
    model.add(m.get_layer('bottleneck'))
    model.add(m.get_layer('bottleneck'))
    model.add(m.get_layer('bottleneck'))
    model.add(Flatten())
    model.add(Dense(100, activation='linear'))

但是在训练模型m之后,在启动错误时引发错误:

ValueError: Input 0 is incompatible with layer bottleneck: expected axis -1 of input shape to have value 497904 but got shape (None, 876990)

1 个答案:

答案 0 :(得分:0)

错误消息试图告诉您,第二个模型中与“瓶颈”层相关的输入形状与第一个模型不同。

为了重新使用一个图层,您需要匹配该图层的输入数量。在您的情况下,第一个模型对此层有497904个输入,但是您尝试在下一个模型中将其与具有876990个输入的输入层一起使用。

我怀疑您想要更多类似的东西(请注意,在每种情况下我都会立即变平,以便我们可以更好地把握每一层的输入数量):

m = Sequential()
m.add(Flatten(input_shape=(943, 1)))
m.add(Dense(912, activation='relu'))
m.add(Dense(728, activation='relu'))
m.add(Dense(528, activation='relu'))
m.add(Dense(500, activation='relu', name="bottleneck"))
m.add(Dense(528, activation='relu'))
m.add(Dense(728, activation='relu'))
m.add(Dense(943, activation='linear'))

m.compile(loss='mean_squared_error', optimizer='SGD')
m.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
flatten (Flatten)            (None, 943)               0         
_________________________________________________________________
dense (Dense)                (None, 912)               860928    
_________________________________________________________________
dense_1 (Dense)              (None, 728)               664664    
_________________________________________________________________
dense_2 (Dense)              (None, 528)               384912    
_________________________________________________________________
bottleneck (Dense)           (None, 500)               264500    
_________________________________________________________________
dense_3 (Dense)              (None, 528)               264528    
_________________________________________________________________
dense_4 (Dense)              (None, 728)               385112    
_________________________________________________________________
dense_5 (Dense)              (None, 943)               687447    
=================================================================
Total params: 3,512,091
Trainable params: 3,512,091
Non-trainable params: 0

请注意,瓶颈层的输入具有(None,528)的形状。现在,在第二个模型中,我们可以执行以下操作:

model = Sequential()
model.add(Dense(930, activation='relu', input_shape=(943, 1)))
model.add(Flatten())
model.add(Dense(528, activation='relu'))
model.add(m.get_layer('bottleneck'))
model.add(Flatten())
model.add(Dense(100, activation='linear'))
model.summary()
Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_9 (Dense)              (None, 943, 930)          1860      
_________________________________________________________________
flatten_3 (Flatten)          (None, 876990)            0         
_________________________________________________________________
dense_10 (Dense)             (None, 528)               463051248 
_________________________________________________________________
bottleneck (Dense)           (None, 500)               264500    
_________________________________________________________________
flatten_4 (Flatten)          (None, 500)               0         
_________________________________________________________________
dense_11 (Dense)             (None, 100)               50100     
=================================================================
Total params: 463,367,708
Trainable params: 463,367,708
Non-trainable params: 0