如何在自定义keras图层中使用keras图层

时间:2019-01-15 07:58:21

标签: tensorflow keras

我正在尝试编写自己的keras层。在这一层中,我想使用其他一些keras层。有什么办法可以做这样的事情:

class MyDenseLayer(tf.keras.layers.Layer):
  def __init__(self, num_outputs):
    super(MyDenseLayer, self).__init__()
    self.num_outputs = num_outputs

  def build(self, input_shape):
    self.fc = tf.keras.layers.Dense(self.num_outputs)

  def call(self, input):
    return self.fc(input)

layer = MyDenseLayer(10)

当我做类似的事情

input = tf.keras.layers.Input(shape = (16,))
output = MyDenseLayer(10)(input)
model = tf.keras.Model(inputs = [input], outputs = [output])
model.summary()

它输出 enter image description here

我如何使密集人群中的训练变得容易?

4 个答案:

答案 0 :(得分:4)

如果您查看有关如何添加自定义图层的文档,则他们建议您使用.add_weight(...)方法。此方法在内部将所有权重放置在self._trainable_weights中。因此,要做您想做的事情,您首先必须定义要使用的keras图层,构建它们,复制权重,然后构建您自己的图层。如果我更新您的代码,则应该是

class mylayer(tf.keras.layers.Layer):
    def __init__(self, num_outputs, num_outputs2):
        self.num_outputs = num_outputs
        super(mylayer, self).__init__()

    def build(self, input_shape):
        self.fc = tf.keras.layers.Dense(self.num_outputs)
        self.fc.build(input_shape)
        self._trainable_weights = self.fc.trainable_weights
        super(mylayer, self).build(input_shape)

    def call(self, input):
        return self.fc(input)

layer = mylayer(10)
input = tf.keras.layers.Input(shape=(16, ))
output = layer(input)
model = tf.keras.Model(inputs=[input], outputs=[output])
model.summary()

然后您应该得到想要的 enter image description here

答案 1 :(得分:1)

将现有图层放入tf.keras.models.Model类要舒适得多且简洁得多。如果定义非自定义图层,例如layers,conv2d,则默认情况下这些图层的参数不可训练。

class MyDenseLayer(tf.keras.Model):
  def __init__(self, num_outputs):
    super(MyDenseLayer, self).__init__()
    self.num_outputs = num_outputs
    self.fc = tf.keras.layers.Dense(num_outputs)

  def call(self, input):
    return self.fc(input)

  def compute_output_shape(self, input_shape):
    shape = tf.TensorShape(input_shape).as_list()
    shape[-1] = self.num_outputs
    return tf.TensorShape(shape)

layer = MyDenseLayer(10)

查看本教程:https://www.tensorflow.org/guide/keras#model_subclassing

答案 2 :(得分:1)

TF2 custom layer Guide中,他们“建议在__init__方法中创建此类子层(由于子层通常具有build方法,因此将在构建外层时构建它们)。”因此,只需将self.fc的创建移到__init__就可以得到想要的。

class MyDenseLayer(tf.keras.layers.Layer):
  def __init__(self, num_outputs):
    super(MyDenseLayer, self).__init__()
    self.num_outputs = num_outputs
    self.fc = tf.keras.layers.Dense(self.num_outputs)

  def build(self, input_shape):
    self.built = True

  def call(self, input):
    return self.fc(input)

input = tf.keras.layers.Input(shape = (16,))
output = MyDenseLayer(10)(input)
model = tf.keras.Model(inputs = [input], outputs = [output])
model.summary()

输出:

Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         [(None, 16)]              0         
_________________________________________________________________
my_dense_layer_2 (MyDenseLay (None, 10)                170       
=================================================================
Total params: 170
Trainable params: 170
Non-trainable params: 0

答案 3 :(得分:0)

这对我有用,而且干净、简洁且可读。

import tensorflow as tf


class MyDense(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(MyDense, self).__init__(kwargs)
        self.dense = tf.keras.layers.Dense(2, tf.keras.activations.relu)

    def call(self, inputs, training=None):
        return self.dense(inputs)


inputs = tf.keras.Input(shape=10)
outputs = MyDense(trainable=True)(inputs)
model = tf.keras.Model(inputs=inputs, outputs=outputs, name='test')
model.compile(loss=tf.keras.losses.MeanSquaredError())
model.summary()

输出:

Model: "test"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 10)]              0         
_________________________________________________________________
my_dense (MyDense)           (None, 2)                 22        
=================================================================
Total params: 22
Trainable params: 22
Non-trainable params: 0
_________________________________________________________________

请注意需要 trainable=True。我已经发布了一个关于它的问题 here