我是Keras的新手。
Keras's docs展示了如何制作自定义图层,使您可以完全控制可训练的权重。
我的问题是,如何才能扩展现有层?
例如,BatchNormalization
层没有激活选项,在实践中,通常可以在批处理归一化之后添加激活函数。
此尝试无效:
class BatchNormalizationActivation(keras.layers.BatchNormalization):
def __init__(self, bn_params={}, activation=keras.activations.relu, act_params={}):
super(BatchNormalizationActivation, self).__init__(**bn_params)
self.act = activation
def call(x):
x = super(BatchNormalizationActivation, self).call(x)
return self.act(x, **act_params)
BatchNormalizationActivation()
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-182-6e4f0495a112> in <module>()
----> 1 BatchNormalizationActivation()
<ipython-input-181-2d5c8337234a> in __init__(self, bn_params, activation, act_params)
3
4 def __init__(self, bn_params={}, activation=keras.activations.relu, act_params={}):
----> 5 super(BatchNormalizationActivation, self).__init__(**bn_params)
6 self.act = activation
7
/usr/local/lib/python3.6/site-packages/tensorflow/python/keras/_impl/keras/layers/normalization.py in __init__(self, axis, momentum, epsilon, center, scale, beta_initializer, gamma_initializer, moving_mean_initializer, moving_variance_initializer, beta_regularizer, gamma_regularizer, beta_constraint, gamma_constraint, **kwargs)
105 beta_constraint=constraints.get(beta_constraint),
106 gamma_constraint=constraints.get(gamma_constraint),
--> 107 **kwargs
108 )
109
/usr/local/lib/python3.6/site-packages/tensorflow/python/layers/normalization.py in __init__(self, axis, momentum, epsilon, center, scale, beta_initializer, gamma_initializer, moving_mean_initializer, moving_variance_initializer, beta_regularizer, gamma_regularizer, beta_constraint, gamma_constraint, renorm, renorm_clipping, renorm_momentum, fused, trainable, virtual_batch_size, adjustment, name, **kwargs)
144 **kwargs):
145 super(BatchNormalization, self).__init__(
--> 146 name=name, trainable=trainable, **kwargs)
147 if isinstance(axis, list):
148 self.axis = axis[:]
/usr/local/lib/python3.6/site-packages/tensorflow/python/keras/_impl/keras/engine/base_layer.py in __init__(self, **kwargs)
147 super(Layer, self).__init__(
148 name=name, dtype=dtype, trainable=trainable,
--> 149 activity_regularizer=kwargs.get('activity_regularizer'))
150 self._uses_inputs_arg = True
151
/usr/local/lib/python3.6/site-packages/tensorflow/python/layers/base.py in __init__(self, trainable, name, dtype, activity_regularizer, **kwargs)
130 self._graph = None # Will be set at build time.
131 self._dtype = None if dtype is None else dtypes.as_dtype(dtype).name
--> 132 self._call_fn_args = estimator_util.fn_args(self.call)
133 self._compute_previous_mask = ('mask' in self._call_fn_args or
134 hasattr(self, 'compute_mask'))
/usr/local/lib/python3.6/site-packages/tensorflow/python/estimator/util.py in fn_args(fn)
60 args = tf_inspect.getfullargspec(fn).args
61 if _is_bounded_method(fn):
---> 62 args.remove('self')
63 return tuple(args)
64
ValueError: list.remove(x): x not in list