需要构建Keras子模型

时间:2019-06-26 11:00:34

标签: python tensorflow keras

我正在创建一个从keras模型继承的python类。

true == '9'

这给了我这个错误:

class MyModel(tf.keras.models.Model):

    def __init__(self, size, input_shape):
        super(MyModel, self).__init__()
        self.layer = tf.keras.layers.Dense(size, input_shape=(input_shape,))

    def call(self, inputs):
        return self.layer(inputs)

model = MyModel(5, 30)
model.summary()

如果我在创建模型后添加一行,则此问题已解决:

ValueError: This model has not yet been built. Build the model first by calling `build()` or calling `fit()` with some data, or specify an `input_shape` argument in the first layer(s) for automatic build.

但这并不是执行此操作的最佳方法。我该如何解决?

1 个答案:

答案 0 :(得分:1)

您可以在构造函数中调用self.build()

类似这样的东西:

class MyModel(tf.keras.models.Model):

    def __init__(self, size, input_shape):
        super(MyModel, self).__init__()
        self.layer = tf.keras.layers.Dense(size, input_shape=(input_shape,))
        self.build(input_shape)

    def call(self, inputs):
        return self.layer(inputs)

model = MyModel(2, (5, 30))
model.summary()