最小示例如下所示
from keras.models import Model
from keras.layers import Dense, Input, Lambda, Concatenate
from keras.models import load_model
from keras.optimizers import Adam
def build_model_with_loop():
x = Input(shape=(22,), name='inputs')
# 2 branches
xls = [[] for i in range(2)]
branch = [[0,10],[10,22]]
for _i, (b, e) in enumerate(branch):
xls[_i] = Lambda(lambda x: x[:, b:e])(x)
c = Concatenate()(xls)
y = Dense(1)(c)
model = Model(inputs=x, outputs=y)
model.compile(loss='mse', optimizer=Adam(1E-3))
return model
def build_model_without_loop():
x = Input(shape=(22,), name='inputs')
# 2 branches
xls = [[] for i in range(2)]
xls[0] = Lambda(lambda x: x[:, 0:10])(x)
xls[1] = Lambda(lambda x: x[:, 10:22])(x)
c = Concatenate()(xls)
y = Dense(1)(c)
model = Model(inputs=x, outputs=y)
model.compile(loss='mse', optimizer=Adam(1E-3))
return model
model = build_model_without_loop()
model.save('model_test.h5')
model = load_model('model_test.h5')
该模型非常简单。输入形状为[batch_size, 22]
,模型将首先将输入分成两个分别为形状[batch_size, 10]
和[batch_size, 12]
的分支。在这里,我们在喀拉拉邦使用Lambda
层进行分隔。但是,如果我们在模型的定义中使用for
循环,则model_test.h5
无法正确加载保存的模型load_model
。错误消息显示如下
Traceback (most recent call last):
File "/home/junjiechen/.pyenv/versions/3.6.5/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1659, in _create_c_op
c_op = c_api.TF_FinishOperation(op_desc)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Dimension 0 in both shapes must be equal, but are 24 and 22. Shapes are [24,1] and [22,1]. for 'Assign' (op: 'Assign') with input shapes: [24,1], [22,1].
ValueError: Dimension 0 in both shapes must be equal, but are 24 and 22. Shapes are [24,1] and [22,1]. for 'Assign' (op: 'Assign') with input shapes: [24,1], [22,1].
但是,如果我们使用build_model_without_loop
,则一切正常。如何解决此问题并正确使用模型定义中的循环?
实际上,问题也可能来自Lambda
层。如果两个分支的维数为[batch_size, 11]
,则不会发生错误。
答案 0 :(得分:0)
解决方案是在arguments
层中使用Lambda
选项,以便可以正确传递相应的参数。实际上,上面显示的问题是由于python的lambda
函数引起的。
def build_model_with_loop():
x = Input(shape=(22,), name='inputs')
# 2 branches
xls = [[] for i in range(2)]
branch = [[0,10],[10,22]]
def get_branch(x, beg, end):
return x[:, beg:end]
for i, (b, e) in enumerate(branch):
xls[i] = Lambda(get_branch, arguments={'beg':b, 'end':e})(x)
c = Concatenate()(xls)
y = Dense(1)(c)
model = Model(inputs=x, outputs=y)
model.compile(loss='mse', optimizer=Adam(1E-3))
return model
在这里我们定义get_branch
函数包装器,并在beg
层中通过end
选项传递arguments
和Lambda
。