Custom attention layer raises AttributeError during model saving

Date: 2020-05-12 10:00:03

Tags: python tensorflow

I am working on a model that uses the custom attention layer below.

Note: here is a Colab notebook that reproduces a similar error:

https://colab.research.google.com/drive/1RDcJwpVbT6JR8_LA52r1nHPSK0w1HuY7?usp=sharing

from tensorflow.keras import backend as K
from tensorflow.keras import initializers
from tensorflow.keras.layers import InputSpec, Layer


class AttentionWeightedAverage(Layer):
    """Computes a weighted average over the timesteps, with the weight of
    each timestep learned through a single trainable attention vector."""

    def __init__(self, return_attention=False, **kwargs):
        self.init = initializers.get('uniform')
        self.supports_masking = True
        self.return_attention = return_attention
        super(AttentionWeightedAverage, self).__init__(**kwargs)

    def build(self, input_shape):
        self.input_spec = [InputSpec(ndim=3)]
        assert len(input_shape) == 3

        self.w = self.add_weight(shape=(input_shape[2], 1),
                                 name='{}_w'.format(self.name),
                                 initializer=self.init,
                                 trainable=True)
        super(AttentionWeightedAverage, self).build(input_shape)

    def call(self, h, mask=None):
        h_shape = K.shape(h)
        d_w, T = h_shape[0], h_shape[1]

        logits = K.dot(h, self.w)  # w^T h
        logits = K.reshape(logits, (d_w, T))
        alpha = K.exp(logits - K.max(logits, axis=-1, keepdims=True))  # exp

        # masked timesteps have zero weight
        if mask is not None:
            mask = K.cast(mask, K.floatx())
            alpha = alpha * mask

        alpha = alpha / (K.sum(alpha, axis=1, keepdims=True) + K.epsilon())  # softmax
        r = K.sum(h * K.expand_dims(alpha), axis=1)  # r = h * alpha^T
        h_star = K.tanh(r)  # h^* = tanh(r)
        if self.return_attention:
            return [h_star, alpha]
        return h_star

    def get_output_shape_for(self, input_shape):
        return self.compute_output_shape(input_shape)

    def compute_output_shape(self, input_shape):
        output_len = input_shape[2]
        if self.return_attention:
            return [(input_shape[0], output_len),
                    (input_shape[0], input_shape[1])]
        return (input_shape[0], output_len)

    def compute_mask(self, input, input_mask=None):
        if isinstance(input_mask, list):
            return [None] * len(input_mask)
        return None
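
Called eagerly on a dummy tensor, the layer itself behaves as expected. A minimal sanity check (the shapes here are illustrative only, not from my actual data):

import tensorflow as tf

h = tf.random.normal((2, 5, 8))  # batch of 2, 5 timesteps, 8 features
layer = AttentionWeightedAverage(return_attention=True)
h_star, alpha = layer(h)
print(h_star.shape)  # (2, 8) -- attention-weighted average over time
print(alpha.shape)   # (2, 5) -- one attention weight per timestep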

My model architecture looks like this:

x = Dense(...)(x)
x = Bidirectional(LSTM(..., return_sequences=True))(x)
x = AttentionWeightedAverage()(x)
output = Dense(1, activation='softmax')(x)
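
For reference, a minimal sketch of how this stack can be wired together (the input shape, unit counts, and loss are placeholder assumptions, not values from my actual model):

from tensorflow.keras.layers import Input, Dense, LSTM, Bidirectional
from tensorflow.keras.models import Model

inputs = Input(shape=(100, 64))                        # (timesteps, features) -- placeholder
x = Dense(64)(inputs)
x = Bidirectional(LSTM(64, return_sequences=True))(x)  # -> (batch, 100, 128)
x = AttentionWeightedAverage()(x)                      # custom layer from above -> (batch, 128)
output = Dense(1, activation='softmax')(x)

model = Model(inputs, output)
model.compile(optimizer='adam', loss='binary_crossentropy')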

After training for a few epochs, when I try to save the model I get the error below. I think it is related to the custom attention layer I am using, but I am not sure. Any help is appreciated.

The error occurs only when I try to save the entire model with model.save(); it does not occur when I use model.save_weights().
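
Concretely (the file paths are placeholders):

model.save('saved_models/my_model')              # full SavedModel export -- raises the error below
model.save_weights('saved_models/my_model.h5')   # weights only -- works fine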

I am using TensorFlow 2.1.0.

Here is the traceback:

Traceback (most recent call last):
  File "classifiers/main.py", line 26, in <module>
    main()
  File "classifiers/main.py", line 18, in main
    clf.model.save(f'./classifiers/saved_models/{args.model_name}')
  File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\keras\engine\network.p
    signatures, options)
  File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\keras\saving\save.py",
    signatures, options)
  File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\keras\saving\saved_mod
    save_lib.save(model, filepath, signatures, options)
  File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\saved_model\save.py", 
    checkpoint_graph_view)
  File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\saved_model\signature_
    functions = saveable_view.list_functions(saveable_view.root)
  File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\saved_model\save.py", 
    self._serialization_cache)
  File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\keras\engine\base_laye
    .list_functions_for_serialization(serialization_cache))
  File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\keras\saving\saved_mod
    fns = self.functions_to_serialize(serialization_cache)
  File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\keras\saving\saved_mod
    serialization_cache).functions_to_serialize)
  File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\keras\saving\saved_mod
    serialization_cache)
  File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\keras\saving\saved_mod
    serialization_cache))
  File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\keras\saving\saved_mod
    functions = save_impl.wrap_layer_functions(self.obj, serialization_cache)
  File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\keras\saving\saved_mod
    original_fns = _replace_child_layer_functions(layer, serialization_cache)
  File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\keras\saving\saved_mod
    serialization_cache).functions)
  File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\keras\saving\saved_mod
    serialization_cache)
  File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\keras\saving\saved_mod
    functions = save_impl.wrap_layer_functions(self.obj, serialization_cache)
  File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\keras\saving\saved_mod
    '{}_layer_call_and_return_conditional_losses'.format(layer.name))
  File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\keras\saving\saved_mod
    self.add_trace(*self._input_signature)
  File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\keras\saving\saved_mod
    fn.get_concrete_function(*args, **kwargs)
  File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\keras\saving\saved_mod
    return super(LayerCall, self).get_concrete_function(*args, **kwargs)
  File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\eager\def_function.py"
    self._initialize(args, kwargs, add_initializers_to=initializers)
  File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\eager\def_function.py"
    *args, **kwds))
  File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\eager\function.py", li
    graph_function, _, _ = self._maybe_define_function(args, kwargs)
  File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\eager\function.py", li
    graph_function = self._create_graph_function(args, kwargs)
  File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\eager\function.py", li
    capture_by_value=self._capture_by_value),
  File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\framework\func_graph.p
    func_outputs = python_func(*func_args, **func_kwargs)
  File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\eager\def_function.py"
    return weak_wrapped_fn().__wrapped__(*args, **kwds)
    return layer_call(inputs, *args, **kwargs), layer.get_losses_for(inputs)
  File "C:\Users\user\miniconda3\envs\user\lib\site-packages\classifiers\blstm_attention.py", line 43, in 
call
    logits = K.dot(h, self.w)  # w^T h
  File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\keras\backend.py", line 1653, in dot
    if ndim(x) is not None and (ndim(x) > 2 or ndim(y) > 2):
  File "C:\Users\user\miniconda3\envs\user\lib\site-packages\tensorflow_core\python\keras\backend.py", line 1202, in ndim
    dims = x.shape._dims
AttributeError: 'list' object has no attribute 'shape'

0 Answers