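I implemented a minimal Wavenet example following the steps at https://github.com/basveeling/wavenet.
The problem is that the model uses a custom layer. It works fine during training, but as soon as I reload the model, Keras cannot find the causal layer, even though I pass custom objects.
I am using tensorflow 1.13 and keras 2.2.4.
Here is what the first three key/value pairs of objects look like when printed: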
objects = {'initial_causal_conv': <class 'wavenet_utils.CausalConv1D'>,
           'dilated_conv_1_tanh_s0': <class 'wavenet_utils.CausalConv1D'>,
           'dilated_conv_1_sigm_s0': <class 'wavenet_utils.CausalConv1D'>,
           '...': <class 'wavenet_utils.CausalConv1D'>,
           '...': <class 'wavenet_utils.CausalConv1D'>}

model.fit(x=[x_tr1, x_tr2],
          y=y_tr1,
          epochs=epochs,
          batch_size=batch_size,
          validation_data=([x_vl1, x_vl2], y_vl1),
          callbacks=[checkpoint, early_stopping],
          verbose=verbose,
          shuffle=True,
          class_weight=class_weight)
model = load_model('model.h5', custom_objects=objects)
This returns the following error:
Traceback (most recent call last):
  File "/home/xxx/PycharmProjects/WAVE/DATA_NN.py", line 48, in <module>
    objects=objects)
  File "/home/xxx/PycharmProjects/WAVE/functions.py", line 572, in run_neural_net
    model = load_model('model_conv.h5', custom_objects=objects)
  File "/home/xxx/PycharmProjects/WAVE/venv/lib/python3.6/site-packages/keras/engine/saving.py", line 419, in load_model
    model = _deserialize_model(f, custom_objects, compile)
  File "/home/xxx/PycharmProjects/WAVE/venv/lib/python3.6/site-packages/keras/engine/saving.py", line 225, in _deserialize_model
    model = model_from_config(model_config, custom_objects=custom_objects)
  File "/home/xxx/PycharmProjects/WAVE/venv/lib/python3.6/site-packages/keras/engine/saving.py", line 458, in model_from_config
    return deserialize(config, custom_objects=custom_objects)
  File "/home/xxx/PycharmProjects/WAVE/venv/lib/python3.6/site-packages/keras/layers/__init__.py", line 55, in deserialize
    printable_module_name='layer')
  File "/home/xxx/PycharmProjects/WAVE/venv/lib/python3.6/site-packages/keras/utils/generic_utils.py", line 145, in deserialize_keras_object
    list(custom_objects.items())))
  File "/home/xxx/PycharmProjects/WAVE/venv/lib/python3.6/site-packages/keras/engine/network.py", line 1022, in from_config
    process_layer(layer_data)
  File "/home/xxx/PycharmProjects/WAVE/venv/lib/python3.6/site-packages/keras/engine/network.py", line 1008, in process_layer
    custom_objects=custom_objects)
  File "/home/xxx/PycharmProjects/WAVE/venv/lib/python3.6/site-packages/keras/layers/__init__.py", line 55, in deserialize
    printable_module_name='layer')
  File "/home/xxx/PycharmProjects/WAVE/venv/lib/python3.6/site-packages/keras/utils/generic_utils.py", line 138, in deserialize_keras_object
    ': ' + class_name)
ValueError: Unknown layer: CausalConv1D
CausalConv1D has to be imported from wavenet_utils.py when building the model.
Here is wavenet_utils, which contains the CausalConv1D class:
from keras.layers import Conv1D
from keras.utils.conv_utils import conv_output_length
import tensorflow as tf


class CausalConv1D(Conv1D):
    def __init__(self, filters, kernel_size, init='glorot_uniform', activation=None,
                 padding='valid', strides=1, dilation_rate=1, bias_regularizer=None,
                 activity_regularizer=None, kernel_constraint=None, bias_constraint=None, use_bias=True, causal=False,
                 output_dim=1,
                 **kwargs):
        self.output_dim = output_dim
        super(CausalConv1D, self).__init__(filters,
                                           kernel_size=kernel_size,
                                           strides=strides,
                                           padding=padding,
                                           dilation_rate=dilation_rate,
                                           activation=activation,
                                           use_bias=use_bias,
                                           kernel_initializer=init,
                                           activity_regularizer=activity_regularizer,
                                           bias_regularizer=bias_regularizer,
                                           kernel_constraint=kernel_constraint,
                                           bias_constraint=bias_constraint,
                                           **kwargs)
        self.causal = causal
        if self.causal and padding != 'valid':
            raise ValueError("Causal mode dictates border_mode=valid.")

    def build(self, input_shape):
        super(CausalConv1D, self).build(input_shape)

    def call(self, x):
        if self.causal:
            def asymmetric_temporal_padding(x, left_pad=1, right_pad=1):
                pattern = [[0, 0], [left_pad, right_pad], [0, 0]]
                return tf.pad(x, pattern)
            x = asymmetric_temporal_padding(x, self.dilation_rate[0] * (self.kernel_size[0] - 1), 0)
        return super(CausalConv1D, self).call(x)

    def compute_output_shape(self, input_shape):
        input_length = input_shape[1]
        if self.causal:
            input_length += self.dilation_rate[0] * (self.kernel_size[0] - 1)
        length = conv_output_length(input_length,
                                    self.kernel_size[0],
                                    self.padding,
                                    self.strides[0],
                                    dilation=self.dilation_rate[0])
        shape = tf.TensorShape(input_shape).as_list()
        shape[-1] = self.output_dim
        return (input_shape[0], length, self.filters)

    def get_config(self):
        base_config = super(CausalConv1D, self).get_config()
        base_config['output_dim'] = self.output_dim
        return base_config
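The causal branch above left-pads the time axis by dilation_rate[0] * (kernel_size[0] - 1) and then applies a 'valid' convolution. Below is a minimal pure-Python sketch of the resulting length arithmetic, assuming the 'valid' formula matches keras.utils.conv_utils.conv_output_length; the helper names valid_conv_length and causal_conv_length are hypothetical.

```python
def valid_conv_length(input_length, kernel_size, stride=1, dilation=1):
    """Output length of a 'valid' 1D convolution.

    Assumption: this mirrors keras.utils.conv_utils.conv_output_length
    with padding='valid'.
    """
    dilated_kernel = (kernel_size - 1) * dilation + 1
    return (input_length - dilated_kernel + stride) // stride


def causal_conv_length(input_length, kernel_size, stride=1, dilation=1):
    """Length after left-padding by dilation * (kernel_size - 1), as the causal branch does."""
    left_pad = dilation * (kernel_size - 1)
    return valid_conv_length(input_length + left_pad, kernel_size, stride, dilation)


for d in (1, 2, 4, 8):
    # columns: dilation, unpadded 'valid' length, causal (left-padded) length
    print(d, valid_conv_length(100, 2, dilation=d), causal_conv_length(100, 2, dilation=d))
```

With strides=1 the causal length equals the input length for any dilation, whereas an unpadded 'valid' convolution shrinks the sequence by dilation * (kernel_size - 1).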
EDIT:
I had also tried this approach before:
objects = {'CausalConv1D': <class 'wavenet_utils.CausalConv1D'>}

model.fit(x=[x_tr1, x_tr2],
          y=y_tr1,
          epochs=epochs,
          batch_size=batch_size,
          validation_data=([x_vl1, x_vl2], y_vl1),
          callbacks=[checkpoint, early_stopping],
          verbose=verbose,
          shuffle=True,
          class_weight=class_weight)
model = load_model('model.h5', custom_objects=objects)
This returns the following error:
Traceback (most recent call last):
  File "/home/xxx/PycharmProjects/WAVE/DATA_NN.py", line 47, in <module>
    objects=objects)
  File "/home/xxx/PycharmProjects/WAVE/functions.py", line 574, in run_neural_net
    model = load_model('model.h5', custom_objects=objects)
  File "/home/xxx/PycharmProjects/WAVE/venv/lib/python3.6/site-packages/keras/engine/saving.py", line 419, in load_model
    model = _deserialize_model(f, custom_objects, compile)
  File "/home/xxx/PycharmProjects/WAVE/venv/lib/python3.6/site-packages/keras/engine/saving.py", line 225, in _deserialize_model
    model = model_from_config(model_config, custom_objects=custom_objects)
  File "/home/xxx/PycharmProjects/WAVE/venv/lib/python3.6/site-packages/keras/engine/saving.py", line 458, in model_from_config
    return deserialize(config, custom_objects=custom_objects)
  File "/home/xxx/PycharmProjects/WAVE/venv/lib/python3.6/site-packages/keras/layers/__init__.py", line 55, in deserialize
    printable_module_name='layer')
  File "/home/xxx/PycharmProjects/WAVE/venv/lib/python3.6/site-packages/keras/utils/generic_utils.py", line 145, in deserialize_keras_object
    list(custom_objects.items())))
  File "/home/xxx/PycharmProjects/WAVE/venv/lib/python3.6/site-packages/keras/engine/network.py", line 1022, in from_config
    process_layer(layer_data)
  File "/home/xxx/PycharmProjects/WAVE/venv/lib/python3.6/site-packages/keras/engine/network.py", line 1008, in process_layer
    custom_objects=custom_objects)
  File "/home/xxx/PycharmProjects/WAVE/venv/lib/python3.6/site-packages/keras/layers/__init__.py", line 55, in deserialize
    printable_module_name='layer')
  File "/home/xxx/PycharmProjects/WAVE/venv/lib/python3.6/site-packages/keras/utils/generic_utils.py", line 147, in deserialize_keras_object
    return cls.from_config(config['config'])
  File "/home/xxx/PycharmProjects/WAVE/venv/lib/python3.6/site-packages/keras/engine/base_layer.py", line 1109, in from_config
    return cls(**config)
  File "/home/xxx/PycharmProjects/WAVE/wavenet_utils.py", line 26, in __init__
    **kwargs)
TypeError: __init__() got multiple values for keyword argument 'kernel_initializer'
Could this be the issue mentioned here: https://github.com/keras-team/keras/issues/12316?
If so, is there any way to work around it?
Answer 0 (score: 2)
There is only one custom object: CausalConv1D.
import wavenet_utils

objects = {'CausalConv1D': wavenet_utils.CausalConv1D}
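To see why keying the dict by layer names fails, here is a simplified pure-Python model of the lookup. It assumes, as the traceback's 'Unknown layer: ' + class_name suggests, that Keras 2.2.x matches custom_objects against each layer's saved class_name. The function deserialize_layer and the stand-in class are hypothetical, and the real deserialize_keras_object also checks built-in layers and globally registered objects.

```python
# Simplified, hypothetical model of how Keras 2.2.x resolves a saved layer's class.
# Only the custom_objects lookup is modelled here.


class CausalConv1D(object):
    """Stand-in for wavenet_utils.CausalConv1D (illustration only)."""

    def __init__(self, name=None, **kwargs):
        self.name = name


def deserialize_layer(layer_config, custom_objects):
    # The class is found by the saved 'class_name', not by the layer's 'name'.
    class_name = layer_config['class_name']
    if class_name not in custom_objects:
        raise ValueError('Unknown layer: ' + class_name)
    return custom_objects[class_name](**layer_config['config'])


# Each saved layer records its class name separately from its own layer name.
saved_layer = {'class_name': 'CausalConv1D',
               'config': {'name': 'initial_causal_conv'}}

# Keyed by layer name, as in the question: the lookup fails.
failure = None
try:
    deserialize_layer(saved_layer, {'initial_causal_conv': CausalConv1D})
except ValueError as err:
    failure = str(err)

# Keyed by class name, as in this answer: the lookup succeeds.
layer = deserialize_layer(saved_layer, {'CausalConv1D': CausalConv1D})
print(failure, '|', layer.name)
```

One entry per class is enough, no matter how many layers of that class the model contains.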
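Now you have to make sure that the get_config method is correct and contains everything your layer's __init__ method needs.
It is missing the causal attribute, and it contains a kernel_initializer (coming from the base class) that your __init__ method does not support.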
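The TypeError from the edit can be reproduced without Keras. This sketch assumes, as the traceback shows (base_layer.py: return cls(**config)), that from_config simply calls cls(**config). BaseConv and BrokenCausalConv are hypothetical stand-ins for Conv1D and the question's CausalConv1D, reduced to the relevant arguments.

```python
class BaseConv(object):
    """Hypothetical stand-in for keras.layers.Conv1D."""

    def __init__(self, filters, kernel_size, kernel_initializer='glorot_uniform', **kwargs):
        self.filters = filters
        self.kernel_size = kernel_size
        self.kernel_initializer = kernel_initializer

    def get_config(self):
        # The base config stores the initializer under 'kernel_initializer'.
        return {'filters': self.filters,
                'kernel_size': self.kernel_size,
                'kernel_initializer': self.kernel_initializer}

    @classmethod
    def from_config(cls, config):
        # Assumed to match Keras 2.2.4's Layer.from_config.
        return cls(**config)


class BrokenCausalConv(BaseConv):
    """Mirrors the question's signature: 'init' instead of 'kernel_initializer'."""

    def __init__(self, filters, kernel_size, init='glorot_uniform', output_dim=1, **kwargs):
        self.output_dim = output_dim
        # A saved 'kernel_initializer' lands in **kwargs and is then passed
        # a second time explicitly, which raises TypeError.
        super(BrokenCausalConv, self).__init__(filters, kernel_size=kernel_size,
                                               kernel_initializer=init, **kwargs)

    def get_config(self):
        config = super(BrokenCausalConv, self).get_config()
        config['output_dim'] = self.output_dim
        return config


layer = BrokenCausalConv(32, 2)  # constructing directly works fine
error_message = None
try:
    BrokenCausalConv.from_config(layer.get_config())  # reloading does not
except TypeError as err:
    error_message = str(err)
print(error_message)
```

Training works because the layer is only ever constructed directly; the failure appears only when the saved config is fed back through from_config.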
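Let's list every attribute you need and then check which ones are in the base config. Note kernel_initializer in particular!
kernel_initializer is a config item that your __init__ method does not support, so rename the init parameter to kernel_initializer.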
In __init__:
def __init__(self, filters, kernel_size,
             ############## here:
             kernel_initializer='glorot_uniform',
             #############
             activation=None,
             padding='valid', strides=1, dilation_rate=1, bias_regularizer=None,
             activity_regularizer=None, kernel_constraint=None, bias_constraint=None, use_bias=True, causal=False,
             output_dim=1,
             **kwargs):
    # Rest of the body unchanged, except that the base class now receives
    # kernel_initializer=kernel_initializer instead of kernel_initializer=init.
    ...
In get_config, include every __init__ parameter that is not in the base class:
def get_config(self):
    base_config = super(CausalConv1D, self).get_config()
    base_config['causal'] = self.causal
    base_config['output_dim'] = self.output_dim
    return base_config
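Putting both fixes together, pure-Python stand-ins can check that a layer now survives the get_config/from_config round trip. BaseConv and FixedCausalConv are hypothetical, and from_config is again assumed to be cls(**config) as in the traceback.

```python
class BaseConv(object):
    """Hypothetical stand-in for keras.layers.Conv1D."""

    def __init__(self, filters, kernel_size, kernel_initializer='glorot_uniform', **kwargs):
        self.filters = filters
        self.kernel_size = kernel_size
        self.kernel_initializer = kernel_initializer

    def get_config(self):
        return {'filters': self.filters,
                'kernel_size': self.kernel_size,
                'kernel_initializer': self.kernel_initializer}

    @classmethod
    def from_config(cls, config):
        return cls(**config)


class FixedCausalConv(BaseConv):
    # Fix 1: accept 'kernel_initializer' under the base class's own name.
    def __init__(self, filters, kernel_size, kernel_initializer='glorot_uniform',
                 causal=False, output_dim=1, **kwargs):
        self.causal = causal
        self.output_dim = output_dim
        super(FixedCausalConv, self).__init__(filters, kernel_size=kernel_size,
                                              kernel_initializer=kernel_initializer, **kwargs)

    # Fix 2: save every __init__ argument the base config does not cover.
    def get_config(self):
        config = super(FixedCausalConv, self).get_config()
        config['causal'] = self.causal
        config['output_dim'] = self.output_dim
        return config


original = FixedCausalConv(32, 2, kernel_initializer='he_normal', causal=True, output_dim=4)
restored = FixedCausalConv.from_config(original.get_config())
print(restored.get_config())
```

Omitting config['causal'] would not raise an error: the restored layer would silently fall back to causal=False, which is why get_config has to include it.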
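Answer 1 (score: 1)
Somehow, nothing I have tried so far loads the model correctly with load_model. Below is a simple workaround: save only the weights, delete the existing model, build a new model and compile it again, then load the saved weights. The weights are restored correctly even when the model contains custom layers.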
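from keras.callbacks import ModelCheckpoint

model = build_model()
# Save only the weights; the architecture is rebuilt from code, so no layer config is deserialized.
checkpoint = ModelCheckpoint('model.h5', monitor='val_acc',
                             verbose=1, save_best_only=True, save_weights_only=True, mode='max')
# The checkpoint must be passed to fit(), and monitoring 'val_acc' needs validation data (x_val, y_val).
model.fit(x, y, validation_data=(x_val, y_val), callbacks=[checkpoint])
del model
model = build_model()  # rebuild and compile the same architecture
model.load_weights('model.h5')
model.predict(x_test)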