tf.keras.Model的子类无法获得summay()结果

时间:2019-04-15 08:11:37

标签: tensorflow keras tf.keras

我想要构建tf.keras.Model的子类,并希望使用summary函数查看模型结构。但这行不通。以下是我的代码:

import tensorflow as tf

class MyModel(tf.keras.Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = tf.keras.layers.Conv2D(32, 3, activation='relu')
        self.flatten = tf.keras.layers.Flatten()
        self.d1 = tf.keras.layers.Dense(128, activation='relu')
        self.d2 = tf.keras.layers.Dense(10, activation='softmax')

    def call(self, x):
        x = self.conv1(x)
        x = self.flatten(x)
        x = self.d1(x)
        return self.d2(x)

model = MyModel()
model.summary()

错误:

  

ValueError:尚未构建此模型。首先建立模型   通过调用build()或使用某些数据调用fit(),或指定一个   第一层中的input_shape自变量用于自动构建。

3 个答案:

答案 0 :(得分:1)

您需要调用一次每个层以推断形状,然后使用模型的输入形状作为参数来调用build()的{​​{1}}方法:

tf.keras.Model

答案 1 :(得分:1)

编辑@Vlad 的回答以避免此错误ValueError: Input 0 of layer conv2d_10 is incompatible with the layer: : expected min_ndim=4, found ndim=3. Full shape received: (32, 32, 3)

将此行更改为

model.build((32, 32, 3 ))

致:

model.build((None, 32, 32, 3 ))

最终代码:

class MyModel(tf.keras.Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = tf.keras.layers.Conv2D(32, 3, activation='relu')
        self.flatten = tf.keras.layers.Flatten()
        self.d1 = tf.keras.layers.Dense(128, activation='relu')
        self.d2 = tf.keras.layers.Dense(10, activation='softmax')
        x = np.random.normal(size=(1, 32, 32, 3))
        x = tf.convert_to_tensor(x)
        _ = self.call(x)

    def call(self, x):
        x = self.conv1(x)
        x = self.flatten(x)
        x = self.d1(x)
        return self.d2(x)

model = MyModel()
model.build((None, 32, 32, 3 ))
model.summary()

答案 2 :(得分:0)

列出了一个更好的解决方案here。您需要提供一个模型方法来显式推断模型。

import tensorflow as tf
from tensorflow.keras.layers import Input

class MyModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(1)

    def call(self, inputs, **kwargs):
        return self.dense(inputs)

    def model(self):
        x = Input(shape=(1))
        return Model(inputs=[x], outputs=self.call(x))

MyModel().model().summary()