我正在研究模型,并且使用了这个自定义注意层,
注意:这是一个可重现类似错误的colab示例笔记本,
https://colab.research.google.com/drive/1RDcJwpVbT6JR8_LA52r1nHPSK0w1HuY7?usp=sharing
class AttentionWeightedAverage(Layer):
def __init__(self, return_attention=False, **kwargs):
self.init = initializers.get('uniform')
self.supports_masking = True
self.return_attention = return_attention
super(AttentionWeightedAverage, self).__init__(** kwargs)
def build(self, input_shape):
self.input_spec = [InputSpec(ndim=3)]
assert len(input_shape) == 3
self.w = self.add_weight(shape=(input_shape[2], 1),
name='{}_w'.format(self.name),
initializer=self.init, trainable=True)
super(AttentionWeightedAverage, self).build(input_shape)
def call(self, h, mask=None):
h_shape = K.shape(h)
d_w, T = h_shape[0], h_shape[1]
logits = K.dot(h, self.w) # w^T h
logits = K.reshape(logits, (d_w, T))
alpha = K.exp(logits - K.max(logits, axis=-1, keepdims=True)) # exp
# masked timesteps have zero weight
if mask is not None:
mask = K.cast(mask, K.floatx())
alpha = alpha * mask
alpha = alpha / (K.sum(alpha, axis=1, keepdims=True) + K.epsilon()) # softmax
r = K.sum(h * K.expand_dims(alpha), axis=1) # r = h*alpha^T
h_star = K.tanh(r) # h^* = tanh(r)
if self.return_attention:
return [h_star, alpha]
return h_star
def get_output_shape_for(self, input_shape):
return self.compute_output_shape(input_shape)
def compute_output_shape(self, input_shape):
output_len = input_shape[2]
if self.return_attention:
return [(input_shape[0], output_len), (input_shape[0], input_shape[1])]
return (input_shape[0], output_len)
def compute_mask(self, input, input_mask=None):
if isinstance(input_mask, list):
return [None] * len(input_mask)
else:
return None
我的模型架构如下所示
dense()(x)
Bidirectional(lstm(return_sequences=True))(x)
attentionweightedaverage()(x)
dense(1, 'softmax')
经过几次训练后,当我尝试保存模型时,出现以下错误,我认为这与我使用的自定义注意层有关。 我不知道。任何帮助表示赞赏。
仅当我尝试使用model.save
保存整个模型并且如果我使用model.save_weights()
时,才会发生以下错误。
我正在使用tensorflow 2.1.0
这是回溯,
Traceback (most recent call last):
File "classifiers/main.py", line 26, in <module>
main()
File "classifiers/main.py", line 18, in main
clf.model.save(f'./classifiers/saved_models/{args.model_name}')
File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\keras\engine\network.p
signatures, options)
File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\keras\saving\save.py",
signatures, options)
File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\keras\saving\saved_mod
save_lib.save(model, filepath, signatures, options)
File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\saved_model\save.py",
checkpoint_graph_view)
File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\saved_model\signature_
functions = saveable_view.list_functions(saveable_view.root)
File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\saved_model\save.py",
self._serialization_cache)
File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\keras\engine\base_laye
.list_functions_for_serialization(serialization_cache))
File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\keras\saving\saved_mod
fns = self.functions_to_serialize(serialization_cache)
File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\keras\saving\saved_mod
serialization_cache).functions_to_serialize)
File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\keras\saving\saved_mod
serialization_cache)
File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\keras\saving\saved_mod
serialization_cache))
File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\keras\saving\saved_mod
functions = save_impl.wrap_layer_functions(self.obj, serialization_cache)
File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\keras\saving\saved_mod
original_fns = _replace_child_layer_functions(layer, serialization_cache)
File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\keras\saving\saved_mod
serialization_cache).functions)
File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\keras\saving\saved_mod
serialization_cache)
File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\keras\saving\saved_mod
functions = save_impl.wrap_layer_functions(self.obj, serialization_cache)
File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\keras\saving\saved_mod
'{}_layer_call_and_return_conditional_losses'.format(layer.name))
File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\keras\saving\saved_mod
self.add_trace(*self._input_signature)
File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\keras\saving\saved_mod
fn.get_concrete_function(*args, **kwargs)
File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\keras\saving\saved_mod
return super(LayerCall, self).get_concrete_function(*args, **kwargs)
File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\eager\def_function.py"
self._initialize(args, kwargs, add_initializers_to=initializers)
File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\eager\def_function.py"
*args, **kwds))
File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\eager\function.py", li
graph_function, _, _ = self._maybe_define_function(args, kwargs)
File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\eager\function.py", li
graph_function = self._create_graph_function(args, kwargs)
File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\eager\function.py", li
capture_by_value=self._capture_by_value),
File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\framework\func_graph.p
func_outputs = python_func(*func_args, **func_kwargs)
File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\eager\def_function.py"
return weak_wrapped_fn().__wrapped__(*args, **kwds)
return layer_call(inputs, *args, **kwargs), layer.get_losses_for(inputs)
File "C:\Users\user\miniconda3\envs\user\lib\site-packages\classifiers\blstm_attention.py", line 43, in
call
logits = K.dot(h, self.w) # w^T h
File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\keras\backend.py", line 1653, in dot
if ndim(x) is not None and (ndim(x) > 2 or ndim(y) > 2):
File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\keras\backend.py", line 1202, in ndim
dims = x.shape._dims
AttributeError: 'list' object has no attribute 'shape'