Keras:用lambda作为激活函数时,load_model加载模型出错

时间:2017-06-19 12:15:28

标签: keras

我想在指定'axis'参数的情况下执行softmax,而我找到的唯一方法是借助lambda函数。下面是我的代码,其中的Activation层用lambda来实现softmax:

from keras.models import Model
from keras.layers import Input,Dense,Reshape,Activation
from keras.layers.merge import Multiply,Concatenate
from keras.layers.core import Lambda
from keras.activations import softmax
from keras import backend as K
import numpy as np

# Problem dimensions: images are N x M, there are T samples, and the hidden
# layer of the network has H units.
N = 6
M = 6
T = 1000
H = 5

# Toy input creation
# Each of the T samples is an (N, M) Gaussian image (std 1.0) whose mean is
# itself drawn uniformly from [0, 1).  The name `input` shadows the builtin;
# it is kept because the training code below refers to it.
input = np.random.normal(np.random.rand(T, 1, 1), 1., (T, N, M))
input2 = np.random.rand(T, N, M)
input3 = np.random.rand(T, N, M)
input4 = np.random.rand(T, N, M)

# Per-sample class index in {0, 1, 2}: the sample mean is clipped to
# [0, 0.9999] and split into three equal-width bins.
a = input.reshape(T, N * M).mean(axis=1)
a = np.clip(a, 0., 0.9999)
a = np.floor(a * 3).astype(int)
# Repeat the index over the spatial dimensions as (T, N, M).  The previous
# stacking order produced (T, M, N), which only matched when N == M.
a = np.tile(a[:, None, None], (1, N, M))

# Three candidate targets, each assembling rows 0-1, 2-3 and 4+ from a
# different permutation of the three source images.
mix1 = np.concatenate((input2[:, :2, :], input3[:, 2:4, :], input4[:, 4:, :]), axis=1)
mix2 = np.concatenate((input3[:, :2, :], input4[:, 2:4, :], input2[:, 4:, :]), axis=1)
mix3 = np.concatenate((input4[:, :2, :], input2[:, 2:4, :], input3[:, 4:, :]), axis=1)

# The target of each sample is the mix selected by its class index.
output = np.choose(a, [mix1, mix2, mix3])
# Source images stacked on a trailing axis: (T, N, M, 3).
images = np.stack((input2, input3, input4), axis=3)

# models definition
# one general model to be trained and
# one mask model to be used later for testing
input_layer = Input(shape=(N,M))
images_input = Input(shape=(N,M,3))
x = Reshape((N*M,))(input_layer)
x = Dense(H, kernel_initializer='uniform', activation='relu')(x)
# One logit per pixel and per candidate image.  The masks are multiplied
# element-wise with images_input of shape (N, M, 3), so the sizes must be
# N*M*3 and (N, M, 3); the earlier N*N*3 / (N, N, 3) only matched when N == M.
x = Dense(N*M*3, kernel_initializer='uniform')(x)
x = Reshape((N,M,3))(x)
# Softmax across the three candidate images (axis 3).
# NOTE: this anonymous function is what load_model later rejects with
# "Unknown activation function:<lambda>" (see the traceback further down).
masks = Activation(activation=lambda y:softmax(y,axis=3))(x)
# Weight each candidate image by its mask and sum over the candidates.
output_layer = Multiply()([masks,images_input])
output_layer = Lambda(lambda x:K.sum(x,axis=3))(output_layer)
model = Model(inputs=[input_layer,images_input],outputs=output_layer)
mask_model = Model(inputs=input_layer,outputs=masks)

# Configure training: Adam optimizer minimising the mean squared error.
model.compile(optimizer='adam', loss='mean_squared_error')

# Train on the toy data; the returned object is kept in `history`.
history = model.fit(
    [input, images],
    output,
    batch_size=50,
    epochs=200,
)

# Persist both the full model and the mask-only model to HDF5 files.
model.save('test.h5')
mask_model.save('mask_test.h5')

它在训练期间工作正常,但是当我尝试加载文件时,它失败了:

from keras.models import load_model
# This call raises "ValueError: Unknown activation function:<lambda>"
# (full traceback below).
mask_model = load_model('mask_test.h5')

我收到错误:

Traceback (most recent call last):
  File "/home/kresch/general2.py", line 3, in <module>
    mask_model = load_model('mask_test.h5')
  File "/opt/anaconda3/envs/tensorflow/lib/python3.5/site-packages/keras/models.py", line 246, in load_model
    model = model_from_config(model_config, custom_objects=custom_objects)
  File "/opt/anaconda3/envs/tensorflow/lib/python3.5/site-packages/keras/models.py", line 314, in model_from_config
    return layer_module.deserialize(config, custom_objects=custom_objects)
  File "/opt/anaconda3/envs/tensorflow/lib/python3.5/site-packages/keras/layers/__init__.py", line 54, in deserialize
    printable_module_name='layer')
  File "/opt/anaconda3/envs/tensorflow/lib/python3.5/site-packages/keras/utils/generic_utils.py", line 140, in deserialize_keras_object
    list(custom_objects.items())))
  File "/opt/anaconda3/envs/tensorflow/lib/python3.5/site-packages/keras/engine/topology.py", line 2450, in from_config
    process_layer(layer_data)
  File "/opt/anaconda3/envs/tensorflow/lib/python3.5/site-packages/keras/engine/topology.py", line 2419, in process_layer
    custom_objects=custom_objects)
  File "/opt/anaconda3/envs/tensorflow/lib/python3.5/site-packages/keras/layers/__init__.py", line 54, in deserialize
    printable_module_name='layer')
  File "/opt/anaconda3/envs/tensorflow/lib/python3.5/site-packages/keras/utils/generic_utils.py", line 142, in deserialize_keras_object
    return cls.from_config(config['config'])
  File "/opt/anaconda3/envs/tensorflow/lib/python3.5/site-packages/keras/engine/topology.py", line 1242, in from_config
    return cls(**config)
  File "/opt/anaconda3/envs/tensorflow/lib/python3.5/site-packages/keras/layers/core.py", line 287, in __init__
    self.activation = activations.get(activation)
  File "/opt/anaconda3/envs/tensorflow/lib/python3.5/site-packages/keras/activations.py", line 81, in get
    return deserialize(identifier)
  File "/opt/anaconda3/envs/tensorflow/lib/python3.5/site-packages/keras/activations.py", line 73, in deserialize
    printable_module_name='activation function')
  File "/opt/anaconda3/envs/tensorflow/lib/python3.5/site-packages/keras/utils/generic_utils.py", line 160, in deserialize_keras_object
    ':' + function_name)
ValueError: Unknown activation function:<lambda>

Process finished with exit code 1

同样的情况发生在:

# Loading the full model fails in the same way as the mask model.
model = load_model('test.h5')

我使用lambda函数错了吗?或者(更好)有没有办法可以避免使用lambda函数?

1 个答案:

答案 0 :(得分:0)

请定义一个具名的自定义激活函数(不要使用lambda),用它重新构建并保存模型,然后在加载模型时通过custom_objects把该函数传给load_model:

def softmax_axis3(y):
    """Apply softmax along axis 3 of `y`."""
    return softmax(y, axis=3)

# The model must be built and saved with Activation(softmax_axis3) instead of
# the lambda, so that the activation is stored under a name that can be
# looked up again at load time.
# custom_objects is a dict mapping that stored name to the Python callable.
# NOTE(review): if the installed Keras version does not forward
# custom_objects to Activation layers, wrap the call in
# `with CustomObjectScope({'softmax_axis3': softmax_axis3}):` instead -- confirm.
load_model('test.h5', custom_objects={'softmax_axis3': softmax_axis3})