Subclassing built-in layers in Keras?

Date: 2018-08-25 11:03:07

Tags: python tensorflow keras

I'm new to Keras.

The Keras docs show how to write a custom layer, which gives you full control over its trainable weights.

My question is: how can I extend an existing layer instead?

For example, the BatchNormalization layer has no activation option, yet in practice an activation function is usually applied right after batch normalization.
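Today that means stacking two separate layers, e.g. (a minimal sketch, assuming tf.keras; the Dense layer and shapes are just placeholders):

import tensorflow as tf
from tensorflow import keras

# The usual pattern: a separate Activation layer right after BatchNormalization
model = keras.models.Sequential([
    keras.layers.Dense(64, input_shape=(10,)),
    keras.layers.BatchNormalization(),
    keras.layers.Activation('relu'),
])

I would like to fold those last two layers into one custom layer.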

This attempt does not work:

class BatchNormalizationActivation(keras.layers.BatchNormalization):

    def __init__(self, bn_params={}, activation=keras.activations.relu, act_params={}):
        super(BatchNormalizationActivation, self).__init__(**bn_params)
        self.act = activation

    def call(x):
        x = super(BatchNormalizationActivation, self).call(x)
        return self.act(x, **act_params)




BatchNormalizationActivation()

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-182-6e4f0495a112> in <module>()
----> 1 BatchNormalizationActivation()

<ipython-input-181-2d5c8337234a> in __init__(self, bn_params, activation, act_params)
      3 
      4     def __init__(self, bn_params={}, activation=keras.activations.relu, act_params={}):
----> 5         super(BatchNormalizationActivation, self).__init__(**bn_params)
      6         self.act = activation
      7 

/usr/local/lib/python3.6/site-packages/tensorflow/python/keras/_impl/keras/layers/normalization.py in __init__(self, axis, momentum, epsilon, center, scale, beta_initializer, gamma_initializer, moving_mean_initializer, moving_variance_initializer, beta_regularizer, gamma_regularizer, beta_constraint, gamma_constraint, **kwargs)
    105         beta_constraint=constraints.get(beta_constraint),
    106         gamma_constraint=constraints.get(gamma_constraint),
--> 107         **kwargs
    108     )
    109 

/usr/local/lib/python3.6/site-packages/tensorflow/python/layers/normalization.py in __init__(self, axis, momentum, epsilon, center, scale, beta_initializer, gamma_initializer, moving_mean_initializer, moving_variance_initializer, beta_regularizer, gamma_regularizer, beta_constraint, gamma_constraint, renorm, renorm_clipping, renorm_momentum, fused, trainable, virtual_batch_size, adjustment, name, **kwargs)
    144                **kwargs):
    145     super(BatchNormalization, self).__init__(
--> 146         name=name, trainable=trainable, **kwargs)
    147     if isinstance(axis, list):
    148       self.axis = axis[:]

/usr/local/lib/python3.6/site-packages/tensorflow/python/keras/_impl/keras/engine/base_layer.py in __init__(self, **kwargs)
    147     super(Layer, self).__init__(
    148         name=name, dtype=dtype, trainable=trainable,
--> 149         activity_regularizer=kwargs.get('activity_regularizer'))
    150     self._uses_inputs_arg = True
    151 

/usr/local/lib/python3.6/site-packages/tensorflow/python/layers/base.py in __init__(self, trainable, name, dtype, activity_regularizer, **kwargs)
    130     self._graph = None  # Will be set at build time.
    131     self._dtype = None if dtype is None else dtypes.as_dtype(dtype).name
--> 132     self._call_fn_args = estimator_util.fn_args(self.call)
    133     self._compute_previous_mask = ('mask' in self._call_fn_args or
    134                                    hasattr(self, 'compute_mask'))

/usr/local/lib/python3.6/site-packages/tensorflow/python/estimator/util.py in fn_args(fn)
     60     args = tf_inspect.getfullargspec(fn).args
     61     if _is_bounded_method(fn):
---> 62       args.remove('self')
     63   return tuple(args)
     64 

ValueError: list.remove(x): x not in list
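My best guess so far is that the base layer inspects the signature of call and tries to strip 'self' from it, which fails because my call was defined without a self parameter. A revised sketch (untested; assuming tf.keras, and that simply storing act_params on the instance is enough) would be:

from tensorflow import keras

class BatchNormalizationActivation(keras.layers.BatchNormalization):

    def __init__(self, bn_params=None, activation=keras.activations.relu, act_params=None):
        # Forward the batch-norm kwargs to the parent layer
        super(BatchNormalizationActivation, self).__init__(**(bn_params or {}))
        self.act = activation
        self.act_params = act_params or {}

    def call(self, inputs, training=None):
        # 'self' must be in the signature so the base layer can inspect call()
        x = super(BatchNormalizationActivation, self).call(inputs, training=training)
        return self.act(x, **self.act_params)

Is this the intended way to extend a built-in layer, or is there a cleaner approach?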

0 Answers

There are no answers yet.