How do I clone a model with custom objects in Keras?

Posted: 2019-04-18 21:09:46

Tags: tensorflow keras

I have a model with a custom activation function. As a result,

model2 = keras.models.clone_model(model)

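raises an error. I can load a saved model by passing the custom_objects keyword argument, but I don't see an equivalent option on clone_model. Is there any way to do this other than rebuilding the model and transferring the weights?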
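For reference, the fallback mentioned above (rebuild, then copy weights) can be sketched like this. It avoids name lookup entirely because the Python function object is passed directly when the model is rebuilt. `build_model` is a hypothetical helper name; the activation is written with `keras.activations.tanh` rather than `K.tanh` so the sketch also runs on newer Keras versions (an assumption about the reader's setup).

```python
from tensorflow import keras


def myTanh(x):
    # Same custom activation as in the question
    return keras.activations.tanh(x)


def build_model():
    # Hypothetical helper: rebuilding in Python passes the function object itself,
    # so Keras never has to resolve the name "myTanh" from a config
    inp = keras.Input(shape=(10, 10, 1))
    flat = keras.layers.Flatten()(inp)
    out = keras.layers.Dense(20, activation=myTanh)(flat)
    return keras.Model(inp, out)


model = build_model()
model2 = build_model()                   # same architecture, freshly initialized weights
model2.set_weights(model.get_weights())  # copy the (trained) values over
```

The drawback is that the architecture has to be reproducible from code, which is exactly what `clone_model` is meant to spare you.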

Edit:

Here is sample code (a toy problem):

import tensorflow.keras as keras
import tensorflow.keras.backend as K

def myTanh(x):
    return K.tanh(x)

inp = keras.Input(shape=(10,10,1))
flat = keras.layers.Flatten()(inp)
out = keras.layers.Dense(20, activation=myTanh)(flat)
model = keras.Model(inp,out)
model.compile(optimizer=keras.optimizers.Adam(lr=0.001),loss='categorical_crossentropy')

model2 = keras.models.clone_model(model)

Error traceback:

~/.conda/envs/tf-gpu/lib/python3.6/site-packages/tensorflow/python/keras/models.py in clone_model(model, input_tensors)
    269     return _clone_sequential_model(model, input_tensors=input_tensors)
    270   else:
--> 271     return _clone_functional_model(model, input_tensors=input_tensors)
    272 
    273 

~/.conda/envs/tf-gpu/lib/python3.6/site-packages/tensorflow/python/keras/models.py in _clone_functional_model(model, input_tensors)
    129       if layer not in layer_map:
    130         # Clone layer.
--> 131         new_layer = layer.__class__.from_config(layer.get_config())
    132         layer_map[layer] = new_layer
    133         layer = new_layer

~/.conda/envs/tf-gpu/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer.py in from_config(cls, config)
    400         A layer instance.
    401     """
--> 402     return cls(**config)
    403 
    404   def compute_output_shape(self, input_shape):

~/.conda/envs/tf-gpu/lib/python3.6/site-packages/tensorflow/python/keras/layers/core.py in __init__(self, units, activation, use_bias, kernel_initializer, bias_initializer, kernel_regularizer, bias_regularizer, activity_regularizer, kernel_constraint, bias_constraint, **kwargs)
    920         activity_regularizer=regularizers.get(activity_regularizer), **kwargs)
    921     self.units = int(units)
--> 922     self.activation = activations.get(activation)
    923     self.use_bias = use_bias
    924     self.kernel_initializer = initializers.get(kernel_initializer)

~/.conda/envs/tf-gpu/lib/python3.6/site-packages/tensorflow/python/keras/activations.py in get(identifier)
    209   if isinstance(identifier, six.string_types):
    210     identifier = str(identifier)
--> 211     return deserialize(identifier)
    212   elif callable(identifier):
    213     return identifier

~/.conda/envs/tf-gpu/lib/python3.6/site-packages/tensorflow/python/keras/activations.py in deserialize(name, custom_objects)
    200       module_objects=globals(),
    201       custom_objects=custom_objects,
--> 202       printable_module_name='activation function')
    203 
    204 

~/.conda/envs/tf-gpu/lib/python3.6/site-packages/tensorflow/python/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
    210       if fn is None:
    211         raise ValueError('Unknown ' + printable_module_name + ':' +
--> 212                          function_name)
    213     return fn
    214   else:

ValueError: Unknown activation function:myTanh
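The traceback shows the mechanism: `clone_model` rebuilds each layer with `layer.__class__.from_config(layer.get_config())`, the `Dense` config stores the activation only by its name, and `activations.get` then has to look that name up again, which fails for `myTanh`. The pure-Python toy below mimics that round trip. `ToyDense`, `get_activation`, `BUILTIN_ACTIVATIONS` and `GLOBAL_CUSTOM_OBJECTS` are hypothetical stand-ins for illustration, not Keras source code.

```python
import math

# Hypothetical stand-ins for Keras' built-in activations and its global custom-object registry
def linear(v):
    return v

BUILTIN_ACTIVATIONS = {"tanh": math.tanh, "linear": linear}
GLOBAL_CUSTOM_OBJECTS = {}  # plays the role of keras.utils.get_custom_objects()


def get_activation(identifier, custom_objects=None):
    """Toy resolver: callables pass through, strings are looked up by name."""
    if callable(identifier):
        return identifier
    # Later dicts win: explicit custom_objects > global registry > built-ins
    lookup = {**BUILTIN_ACTIVATIONS, **GLOBAL_CUSTOM_OBJECTS, **(custom_objects or {})}
    if identifier not in lookup:
        raise ValueError("Unknown activation function:" + identifier)
    return lookup[identifier]


class ToyDense:
    def __init__(self, units, activation=None):
        self.units = units
        self.activation = get_activation(activation or "linear")

    def get_config(self):
        # Only the function's *name* is stored, not the function object
        return {"units": self.units, "activation": self.activation.__name__}

    @classmethod
    def from_config(cls, config):
        return cls(**config)


def myTanh(x):
    return math.tanh(x)


layer = ToyDense(20, activation=myTanh)

# Mirrors the per-layer step inside clone_model: fails because "myTanh" is unknown
try:
    ToyDense.from_config(layer.get_config())
    error_message = None
except ValueError as err:
    error_message = str(err)
print(error_message)

# Registering the name first (what the first answer does) makes the round trip work
GLOBAL_CUSTOM_OBJECTS.update({"myTanh": myTanh})
clone = ToyDense.from_config(layer.get_config())
print(clone.activation is myTanh)
```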

2 answers:

Answer 0 (score: 1)

I solved the problem by calling

keras.utils.get_custom_objects().update(custom_objects)

after the custom objects have been defined. Keras needs to know about these objects in order to clone the model correctly.

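import tensorflow as tf
from tensorflow import keras

def lrelu(x, alpha=0.2):
    # Leaky ReLU: equals x for x > 0 and alpha * x otherwise
    return tf.nn.relu(x) * (1 - alpha) + x * alpha

# Map the name stored in the layer config to the actual function
custom_objects = {
    'lrelu': lrelu,
}
# Register globally so clone_model can resolve 'lrelu' by name
keras.utils.get_custom_objects().update(custom_objects)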
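If you would rather not register the function globally, Keras also provides `keras.utils.custom_object_scope`, which makes the same name-to-object mapping visible only inside a `with` block. A minimal sketch applied to the question's toy model follows (using `keras.activations.tanh` in place of `K.tanh`, an assumption made so it also runs on newer Keras). Note that `clone_model` creates freshly initialized weights, so the sketch copies them with `set_weights` when an exact replica is wanted.

```python
import numpy as np
from tensorflow import keras


def myTanh(x):
    # Same custom activation as in the question
    return keras.activations.tanh(x)


inp = keras.Input(shape=(10, 10, 1))
flat = keras.layers.Flatten()(inp)
out = keras.layers.Dense(20, activation=myTanh)(flat)
model = keras.Model(inp, out)

# 'myTanh' is resolvable only while this block is active
with keras.utils.custom_object_scope({"myTanh": myTanh}):
    model2 = keras.models.clone_model(model)

# clone_model does not copy weights; do it explicitly for an identical model
model2.set_weights(model.get_weights())

x = np.random.rand(2, 10, 10, 1).astype("float32")
print(np.allclose(model.predict(x, verbose=0), model2.predict(x, verbose=0)))
```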

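Answer 1 (score: 0)

This is an open bug in Keras.

The suggested workaround is to apply the custom function with a Lambda layer instead of an Activation layer (or the activation= argument used in the question):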

x = keras.layers.Lambda(my_custom_activation_function)(x)