我们如何在KERAS的转移学习中加入/组合两个模型?
我有两个型号: model 1 =我的模型 模型2 =训练模型
我可以通过将模型2作为输入来组合这些模型,然后将其输出传递给模型1,这是传统方式。
然而,我正在以其他方式这样做。我想将模型1作为输入,然后将其输出传递给模型2(即训练的模型1)。
答案 0 :(得分:3)
这是完全相同的程序,只需确保模型的输出与其他模型的输入具有相同的形状。
from keras.models import Model
output = model2(model1.outputs)
joinedModel = Model(model1.inputs,output)
确保(如果这是你想要的),在编译之前使模型2中的所有图层都有trainable=False
,因此训练不会改变已经训练过的模型。
测试代码:
from keras.layers import *
from keras.models import Sequential, Model
#creating model 1 and model 2 -- the "previously existing models"
m1 = Sequential()
m2 = Sequential()
m1.add(Dense(20,input_shape=(50,)))
m1.add(Dense(30))
m2.add(Dense(5,input_shape=(30,)))
m2.add(Dense(11))
#creating model 3, joining the models
out2 = m2(m1.outputs)
m3 = Model(m1.inputs,out2)
#checking out the results
m3.summary()
#layers in model 3
print("\nthe main model:")
for i in m3.layers:
print(i.name)
#layers inside the last layer of model 3
print("\ninside the submodel:")
for i in m3.layers[-1].layers:
print(i.name)
<强>输出:强>
Layer (type) Output Shape Param #
=================================================================
dense_21_input (InputLayer) (None, 50) 0
_________________________________________________________________
dense_21 (Dense) (None, 20) 1020
_________________________________________________________________
dense_22 (Dense) (None, 30) 630
_________________________________________________________________
sequential_12 (Sequential) (None, 11) 221
=================================================================
Total params: 1,871
Trainable params: 1,871
Non-trainable params: 0
_________________________________________________________________
the main model:
dense_21_input
dense_21
dense_22
sequential_12
inside the submodel:
dense_23
dense_24
答案 1 :(得分:0)
问题已经解决。
我使用了using System;
using System.Linq;
using System.Globalization;
// ...
// Sample input char.
char c = (char)0x20; // space
// The set of Unicode character categories containing non-rendering,
// unknown, or incomplete characters.
// !! Unicode.Format and Unicode.PrivateUse can NOT be included in
// !! this set, because they may (private-use) or do (format)
// !! contain at least *some* rendering characters.
var nonRenderingCategories = new UnicodeCategory[] {
UnicodeCategory.Control,
UnicodeCategory.OtherNotAssigned,
UnicodeCategory.Surrogate };
// Char.IsWhiteSpace() includes the ASCII whitespace characters that
// are categorized as control characters. Any other character is
// printable, unless it falls into the non-rendering categories.
var isPrintable = Char.IsWhiteSpace(c) ||
! nonRenderingCategories.Contains(Char.GetUnicodeCategory(c));
函数,然后添加了模型1和模型2的所有必需图层。
以下代码将在模型1之后添加模型2的第一个 10层。
model.add()