Error when creating a model with a custom activation function

Time: 2021-06-17 17:39:35

Tags: python tensorflow keras tensorflow2.0 python-3.8

I am trying to implement a custom activation function (pentanh) based on the Tanh activation. However, when I add this function to my model, it raises a ValueError.

The custom activation function:

import tensorflow.keras.backend as K
from tensorflow.keras.layers import Layer


class Pentanh(Layer):

    def __init__(self, **kwargs):
        super(Pentanh, self).__init__(**kwargs)
        self.supports_masking = True
        self.__name__ = 'pentanh'

    def call(self, inputs):
        return K.switch(K.greater(inputs, 0), K.tanh(inputs), 0.25 * K.tanh(inputs))

    def get_config(self):
        return super(Pentanh, self).get_config()

    def compute_output_shape(self, input_shape):
        return input_shape

I use the custom activation function when adding an LSTM layer to the model:

layer_lstm = Bidirectional(LSTM(256, activation="pentanh", return_sequences=True))(layer_embeddings)

Also, before creating the model, I update Keras's custom objects:

from tensorflow import keras
keras.utils.get_custom_objects().update({'pentanh': Pentanh()})

Error:

File "C:\Users\PycharmProjects\code\util.py", line 106, in create_model
  layer_lstm = Bidirectional(LSTM(256, activation="pentanh", return_sequences=True))(layer_embeddings)
File "C:\Users\env\lib\site-packages\tensorflow\python\keras\layers\wrappers.py", line 418, in __init__
  self.forward_layer = self._recreate_layer_from_config(layer)
File "C:\Users\env\lib\site-packages\tensorflow\python\keras\layers\wrappers.py", line 494, in _recreate_layer_from_config
  return layer.__class__.from_config(config)
File "C:\Users\env\lib\site-packages\tensorflow\python\keras\layers\recurrent.py", line 2882, in from_config
  return cls(**config)
File "C:\Users\env\lib\site-packages\tensorflow\python\keras\layers\recurrent_v2.py", line 1057, in __init__
  super(LSTM, self).__init__(
File "C:\Users\env\lib\site-packages\tensorflow\python\keras\layers\recurrent.py", line 1103, in __init__
  super(DropoutRNNCellMixin, self).__init__(*args, **kwargs)
File "C:\Users\env\lib\site-packages\tensorflow\python\keras\layers\recurrent.py", line 2729, in __init__
  cell = LSTMCell(
File "C:\Users\env\lib\site-packages\tensorflow\python\keras\layers\recurrent.py", line 2324, in __init__
  self.activation = activations.get(activation)
File "C:\Users\env\lib\site-packages\tensorflow\python\util\dispatch.py", line 201, in wrapper
  return target(*args, **kwargs)
File "C:\Users\env\lib\site-packages\tensorflow\python\keras\activations.py", line 531, in get
  return deserialize(identifier)
File "C:\Users\env\lib\site-packages\tensorflow\python\util\dispatch.py", line 201, in wrapper
  return target(*args, **kwargs)
File "C:\Users\env\lib\site-packages\tensorflow\python\keras\activations.py", line 488, in deserialize
  return deserialize_keras_object(
File "C:\Users\env\lib\site-packages\tensorflow\python\keras\utils\generic_utils.py", line 346, in deserialize_keras_object
  (cls, cls_config) = class_and_config_for_serialized_keras_object(
File "C:\Users\env\lib\site-packages\tensorflow\python\keras\utils\generic_utils.py", line 296, in class_and_config_for_serialized_keras_object
  raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)

ValueError: Unknown activation function: Pentanh

1 Answer:

Answer 0 (score: 1)

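I really don't know why your function fails only with the LSTM layer. It works with a Dense layer, for example. However, to solve your problem, I defined it as a plain function, which also works with LSTM layers.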

The code is as follows:

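import tensorflow.keras.backend as K
from tensorflow import keras

def my_pentanh(inputs):
    return K.switch(K.greater(inputs, 0), K.tanh(inputs), 0.25 * K.tanh(inputs))

# Register under the same name that is passed as activation='my_pentanh'
keras.utils.get_custom_objects().update({'my_pentanh': my_pentanh})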
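As a quick sanity check of what this activation computes, here is a minimal scalar sketch in plain Python, independent of Keras. The helper name `pentanh_reference` is hypothetical and exists only for illustration; it mirrors the piecewise rule in the code above (tanh for positive inputs, 0.25 * tanh otherwise).

```python
import math


def pentanh_reference(x: float) -> float:
    """Scalar reference for the piecewise rule above (hypothetical helper)."""
    t = math.tanh(x)
    # Positive inputs keep tanh(x); all other inputs are scaled by 0.25.
    return t if x > 0 else 0.25 * t


if __name__ == "__main__":
    for x in (-2.0, -0.5, 0.0, 0.5, 2.0):
        print(f"{x:+.1f} -> {pentanh_reference(x):+.4f}")
```

Comparing a few values from this helper against the output of the Keras function on the same inputs is one way to confirm the two agree.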

And the LSTM layer:

tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(32, return_sequences=True, activation='my_pentanh')),
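
A likely explanation for the original error, inferred from the traceback rather than confirmed: Bidirectional rebuilds the wrapped LSTM from its config (the `_recreate_layer_from_config` frame), and the error names `Pentanh`, the class name, rather than the registered key `pentanh`. That suggests the Layer instance is saved under its class name and then looked up under that name, which was never registered. A plain function is saved under its own name, which matches the registered key. A Dense layer is presumably not rebuilt this way, which would explain why it works there. The toy sketch below mimics that round trip in plain Python. `custom_objects`, `serialized_name`, `lookup`, and `rebuild` are hypothetical stand-ins, not Keras APIs.

```python
import types

# Toy model of name-based lookup. This is NOT Keras's real implementation.
custom_objects = {}  # stands in for keras.utils.get_custom_objects()


class Pentanh:
    """Stand-in for the Layer-style activation from the question."""


def pentanh_fn(x):
    """Stand-in for a plain-function activation."""
    return x


def serialized_name(obj):
    # Assumption: a plain function is saved under its own __name__,
    # while any other object is saved under its class name.
    if isinstance(obj, types.FunctionType):
        return obj.__name__
    return type(obj).__name__


def lookup(name):
    if name not in custom_objects:
        raise ValueError("Unknown activation function: " + name)
    return custom_objects[name]


def rebuild(obj):
    """Round trip: serialize to a name, then look that name up again."""
    return lookup(serialized_name(obj))


custom_objects["pentanh"] = Pentanh()      # lower-case key, as in the question
custom_objects["pentanh_fn"] = pentanh_fn  # key equals the function's __name__

first = lookup("pentanh")  # the initial lookup by registered key succeeds
try:
    rebuild(first)         # the round trip asks for "Pentanh" instead
except ValueError as err:
    print(err)             # Unknown activation function: Pentanh

assert rebuild(pentanh_fn) is pentanh_fn  # the function survives the round trip
```

If this reading is right, registering the class under the key `Pentanh` might also resolve the error, though that is untested here.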