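I'm trying to load a model that uses two custom objects, and I get ValueError: Unknown activation: Activation.
This is where I import/define the functions and register them so that Keras can refer to them by name.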
from tensorflow.keras.utils import get_custom_objects
from tensorflow.python.keras.layers import LeakyReLU
from tensorflow.keras.layers import Activation
from tensorflow.keras.backend import sigmoid
def swish(x, beta=1):
    return x * sigmoid(beta * x)
get_custom_objects().update({'swish': Activation(swish)})
get_custom_objects().update({'lrelu': LeakyReLU()})
I load the model with this part:
from tensorflow.keras.models import load_model
model = load_model('model.h5', custom_objects={'swish': Activation(swish), 'lrelu': LeakyReLU()}, compile=False)
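Conceptually, load_model rebuilds the model from names stored in the file and resolves each name back to a Python object. Judging by the deserialize_keras_object frames in the traceback, the lookup checks, roughly in this order, the custom_objects argument, the global registry filled through get_custom_objects(), and Keras's built-in objects. The sketch below is a simplified pure-Python model of that lookup order, not Keras's actual code; BUILTIN_ACTIVATIONS, GLOBAL_CUSTOM_OBJECTS and resolve are hypothetical names chosen for illustration.

```python
import math

# Hypothetical stand-ins for Keras internals (illustration only).
BUILTIN_ACTIVATIONS = {
    "relu": lambda x: max(0.0, x),
    "linear": lambda x: x,
}
GLOBAL_CUSTOM_OBJECTS = {}  # plays the role of get_custom_objects()


def resolve(name, custom_objects=None, kind="activation"):
    """Resolve a stored name: custom_objects argument, then global registry, then built-ins."""
    if custom_objects and name in custom_objects:
        return custom_objects[name]
    if name in GLOBAL_CUSTOM_OBJECTS:
        return GLOBAL_CUSTOM_OBJECTS[name]
    if name in BUILTIN_ACTIVATIONS:
        return BUILTIN_ACTIVATIONS[name]
    raise ValueError("Unknown " + kind + ": " + name)


def swish(x, beta=1):
    # Plain-float version of x * sigmoid(beta * x)
    return x * (1.0 / (1.0 + math.exp(-beta * x)))


GLOBAL_CUSTOM_OBJECTS["swish"] = swish

# "swish" is found in the global registry...
print(resolve("swish")(0.0))  # 0.0
# ...but a name that was never registered is not.
try:
    resolve("Activation")
except ValueError as err:
    print(err)  # Unknown activation: Activation
```

Note that only the exact stored name matters: registering an object under "swish" does not help if the file refers to something else.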
I get the following error:
Traceback (most recent call last):
File "C:\Users\Ben\PycharmProjects\untitled\trainer.py", line 102, in load_items
model = load_model(data_loc + 'model.h5', custom_objects={'swish': Activation(swish), 'lrelu': LeakyReLU()}, compile=False)
File "C:\Users\Ben\PycharmProjects\untitled\venv\lib\site-packages\tensorflow_core\python\keras\saving\save.py", line 146, in load_model
return hdf5_format.load_model_from_hdf5(filepath, custom_objects, compile)
File "C:\Users\Ben\PycharmProjects\untitled\venv\lib\site-packages\tensorflow_core\python\keras\saving\hdf5_format.py", line 168, in load_model_from_hdf5
custom_objects=custom_objects)
File "C:\Users\Ben\PycharmProjects\untitled\venv\lib\site-packages\tensorflow_core\python\keras\saving\model_config.py", line 55, in model_from_config
return deserialize(config, custom_objects=custom_objects)
File "C:\Users\Ben\PycharmProjects\untitled\venv\lib\site-packages\tensorflow_core\python\keras\layers\serialization.py", line 102, in deserialize
printable_module_name='layer')
File "C:\Users\Ben\PycharmProjects\untitled\venv\lib\site-packages\tensorflow_core\python\keras\utils\generic_utils.py", line 191, in deserialize_keras_object
list(custom_objects.items())))
File "C:\Users\Ben\PycharmProjects\untitled\venv\lib\site-packages\tensorflow_core\python\keras\engine\sequential.py", line 369, in from_config
custom_objects=custom_objects)
File "C:\Users\Ben\PycharmProjects\untitled\venv\lib\site-packages\tensorflow_core\python\keras\layers\serialization.py", line 102, in deserialize
printable_module_name='layer')
File "C:\Users\Ben\PycharmProjects\untitled\venv\lib\site-packages\tensorflow_core\python\keras\utils\generic_utils.py", line 193, in deserialize_keras_object
return cls.from_config(cls_config)
File "C:\Users\Ben\PycharmProjects\untitled\venv\lib\site-packages\tensorflow_core\python\keras\engine\base_layer.py", line 594, in from_config
return cls(**config)
File "C:\Users\Ben\PycharmProjects\untitled\venv\lib\site-packages\tensorflow_core\python\keras\layers\core.py", line 361, in __init__
self.activation = activations.get(activation)
File "C:\Users\Ben\PycharmProjects\untitled\venv\lib\site-packages\tensorflow_core\python\keras\activations.py", line 321, in get
identifier, printable_module_name='activation')
File "C:\Users\Ben\PycharmProjects\untitled\venv\lib\site-packages\tensorflow_core\python\keras\utils\generic_utils.py", line 180, in deserialize_keras_object
config, module_objects, custom_objects, printable_module_name)
File "C:\Users\Ben\PycharmProjects\untitled\venv\lib\site-packages\tensorflow_core\python\keras\utils\generic_utils.py", line 165, in class_and_config_for_serialized_keras_object
raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)
ValueError: Unknown activation: Activation
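In the frames above, Dense.__init__ calls activations.get on the stored activation, and the lookup fails on the name Activation. That name is the class of the object that was registered: a plain function is saved under its __name__ (here "swish"), whereas an object with a get_config method, such as an Activation layer, appears to be saved as a dict whose class_name is the object's class. Since neither custom_objects nor the built-in activations contain a key "Activation", loading fails. The pure-Python sketch below models this under the simplifying assumption that only custom_objects is searched; ToyActivationLayer, serialize_activation and deserialize_activation are hypothetical stand-ins, and in real Keras the recorded class name would be "Activation".

```python
import math


def swish(x, beta=1):
    return x * (1.0 / (1.0 + math.exp(-beta * x)))


class ToyActivationLayer:
    """Hypothetical stand-in for an Activation layer wrapping a function."""

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, x):
        return self.fn(x)

    def get_config(self):
        return {"activation": self.fn.__name__}


def serialize_activation(act):
    # Objects with get_config are stored by their *class* name; functions by __name__.
    if hasattr(act, "get_config"):
        return {"class_name": type(act).__name__, "config": act.get_config()}
    return act.__name__


def deserialize_activation(stored, custom_objects):
    # Look up the stored name; fail the same way Keras does when it is missing.
    name = stored["class_name"] if isinstance(stored, dict) else stored
    if name not in custom_objects:
        raise ValueError("Unknown activation: " + name)
    return custom_objects[name]


# Registering the wrapper object: the file records the wrapper's class name.
custom_objects = {"swish": ToyActivationLayer(swish)}
stored = serialize_activation(custom_objects["swish"])
try:
    deserialize_activation(stored, custom_objects)
except ValueError as err:
    print(err)  # Unknown activation: ToyActivationLayer

# Registering the plain function: the file records "swish", which resolves.
print(deserialize_activation(serialize_activation(swish), {"swish": swish}) is swish)  # True
```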
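It may also be worth noting that I am saving and loading the model in different projects with different environments. Both use tf 2.0.0 gpu, and the imports should be identical.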
Answer 0 (score: 5)
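You shouldn't blindly trust every tutorial on the internet. As I said in the comments, the problem is passing the activation function as a Layer (an Activation, to be precise). This works, but it is not correct, because it causes problems when the model is saved and loaded: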
def swish(x, beta = 1):
    return (x * K.sigmoid(beta * x))
get_custom_objects().update({'swish': Activation(swish)})
model = Sequential()
model.add(Dense(10, input_shape=(1,), activation="swish"))
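The code above is not the right way to do it: the activation inside a layer should not itself be another layer. With this code I got errors in TensorFlow 1.14 when saving the model with model.save and loading it again, with both keras and tf.keras. The correct way is to register the plain function rather than an Activation layer wrapping it:
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential
from tensorflow.keras.utils import get_custom_objects
def swish(x, beta = 1):
    return (x * K.sigmoid(beta * x))
get_custom_objects().update({'swish': swish})
model = Sequential()
model.add(Dense(10, input_shape=(1,), activation="swish"))
Then you will be able to save and load the model correctly. If you need the activation as a layer of its own, add a separate Activation layer that wraps the registered function (for example model.add(Activation('swish'))) instead of registering the Activation object itself; this also allows the model to be saved and loaded.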
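Because the .h5 file stores the model architecture as a JSON string, only the name "swish" is saved, not the function's code. So when saving in one project and loading in another, as in the question, the loading project presumably has to define swish itself and either register it or pass it as custom_objects={'swish': swish}. The sketch below illustrates this with a JSON round trip in plain Python; the layer layout and the load_activation helper are hypothetical simplifications, not the real file format or a Keras API.

```python
import json
import math


def swish(x, beta=1):
    return x * (1.0 / (1.0 + math.exp(-beta * x)))


# Toy version of what ends up in the file: a JSON string that holds names only.
layer_config = {"class_name": "Dense", "config": {"units": 10, "activation": swish.__name__}}
saved = json.dumps(layer_config)


def load_activation(saved_json, custom_objects):
    """Hypothetical loader: recover the activation function from its stored name."""
    name = json.loads(saved_json)["config"]["activation"]
    if name not in custom_objects:
        raise ValueError("Unknown activation: " + name)
    return custom_objects[name]


# The loading project supplies its own definition under the same name.
act = load_activation(saved, custom_objects={"swish": swish})
print(round(act(2.0), 4))  # 1.7616
```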
Answer 1 (score: 1)
I just use the following line, and it works fine!
activation=tf.nn.swish
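tf.nn.swish computes x * sigmoid(x), so it matches the custom swish from the question only for the default beta=1; if you rely on a different beta, the built-in is not a drop-in replacement. The plain-Python sketch below checks this numerically; swish_fixed is a hypothetical name for a pure-Python version of what tf.nn.swish computes, not the TensorFlow function itself.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def swish(x, beta=1):
    # Custom version from the question: x * sigmoid(beta * x)
    return x * sigmoid(beta * x)


def swish_fixed(x):
    # Plain-Python version of what tf.nn.swish computes: x * sigmoid(x)
    return x * sigmoid(x)


for x in (-3.0, -0.5, 0.0, 1.0, 4.0):
    print(x, swish(x), swish_fixed(x))  # the two columns agree

# With beta != 1 the two functions differ, e.g. at x = 1.0:
print(swish(1.0, beta=2) == swish_fixed(1.0))  # False
```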