在KERAS中加入/组合两种转移学习模型

时间:2017-08-28 17:50:15

标签: python tensorflow deep-learning keras pre-trained-model

我们如何在KERAS的转移学习中加入/组合两个模型?

我有两个型号: model 1 =我的模型 模型2 =训练模型

我可以通过将模型2作为输入来组合这些模型,然后将其输出传递给模型1,这是传统方式。

然而,我正在以其他方式这样做。我想将模型1作为输入,然后将其输出传递给模型2(即训练的模型1)。

2 个答案:

答案 0 :(得分:3)

这是完全相同的程序,只需确保模型的输出与其他模型的输入具有相同的形状。

from keras.models import Model

output = model2(model1.outputs)
joinedModel = Model(model1.inputs,output)

确保(如果这是你想要的),在编译之前使模型2中的所有图层都有trainable=False,因此训练不会改变已经训练过的模型。

测试代码:

from keras.layers import *
from keras.models import Sequential, Model

#creating model 1 and model 2 -- the "previously existing models"
m1 = Sequential()
m2 = Sequential()
m1.add(Dense(20,input_shape=(50,)))
m1.add(Dense(30))
m2.add(Dense(5,input_shape=(30,)))
m2.add(Dense(11))

#creating model 3, joining the models 
out2 = m2(m1.outputs)
m3 = Model(m1.inputs,out2)

#checking out the results
m3.summary()

#layers in model 3
print("\nthe main model:")
for i in m3.layers:
    print(i.name)

#layers inside the last layer of model 3
print("\ninside the submodel:")
for i in m3.layers[-1].layers:
    print(i.name)

<强>输出:

Layer (type)                 Output Shape              Param #   
=================================================================
dense_21_input (InputLayer)  (None, 50)                0         
_________________________________________________________________
dense_21 (Dense)             (None, 20)                1020      
_________________________________________________________________
dense_22 (Dense)             (None, 30)                630       
_________________________________________________________________
sequential_12 (Sequential)   (None, 11)                221       
=================================================================
Total params: 1,871
Trainable params: 1,871
Non-trainable params: 0
_________________________________________________________________

the main model:
dense_21_input
dense_21
dense_22
sequential_12

inside the submodel:
dense_23
dense_24

答案 1 :(得分:0)

问题已经解决。

我使用了using System; using System.Linq; using System.Globalization; // ... // Sample input char. char c = (char)0x20; // space // The set of Unicode character categories containing non-rendering, // unknown, or incomplete characters. // !! Unicode.Format and Unicode.PrivateUse can NOT be included in // !! this set, because they may (private-use) or do (format) // !! contain at least *some* rendering characters. var nonRenderingCategories = new UnicodeCategory[] { UnicodeCategory.Control, UnicodeCategory.OtherNotAssigned, UnicodeCategory.Surrogate }; // Char.IsWhiteSpace() includes the ASCII whitespace characters that // are categorized as control characters. Any other character is // printable, unless it falls into the non-rendering categories. var isPrintable = Char.IsWhiteSpace(c) || ! nonRenderingCategories.Contains(Char.GetUnicodeCategory(c)); 函数,然后添加了模型1和模型2的所有必需图层。

以下代码将在模型1之后添加模型2的第一个 10层

model.add()