keras模型子类化示例

时间:2018-10-15 23:46:31

标签: tensorflow keras

从Keras 2.2.0开始,发布了第三个模型定义API:模型子类。

根据常见问题解答:

  

但是,在子类模型中,模型的拓扑定义为   Python代码(而不是作为层的静态图)。这意味着   模型的拓扑无法检查或序列化。结果,   以下方法和属性不适用于子类   型号:

     

model.inputs和model.outputs。 model.to_yaml()和model.to_json()   model.get_config()和model.save()。

保存经过训练的模型进行推理的唯一选择是使用 model.save_weights 方法。但是,我没有运回模型进行推理的运气。遇到的错误消息包括:

  

此模型从未被调用过,因此尚未创建其权重,因此无法显示摘要。首先构建模型(例如,通过对一些测试数据进行调用)。   您试图将包含4层的权重文件加载到具有0层的模型中。   NotImplementedError

任何人都可以举一个完整的玩具示例来创建子类的keras模型,train和save_weights,然后将其加载回进行推断吗?

1 个答案:

答案 0 :(得分:1)

在尝试保存子类化模型权重之前,需要调用tf.keras.Model.build方法。替代方法是在尝试保存模型权重之前,在某些输入上调用tf.keras.Model.fit或tf.keras.Model.fit.call。这同样适用于将权重加载到子类模型的新创建实例中。您需要先调用上述方法之一,然后再尝试加载体重。 这是显示子类模型的权重和加载权重的示例

import tensorflow as tf

print('TensorFlow', tf.__version__)

class ResidualBlock(tf.keras.Model):
    def __init__(self, block_type=None, n_filters=None):
        super(ResidualBlock, self).__init__()
        self.n_filters = n_filters
        if block_type == 'identity':
            self.strides = 1
        elif block_type == 'conv':
            self.strides = 2
            self.conv_shorcut = tf.keras.layers.Conv2D(filters=self.n_filters, 
                               kernel_size=1, 
                               padding='same',
                               strides=self.strides,
                               kernel_initializer='he_normal')
            self.bn_shortcut = tf.keras.layers.BatchNormalization(momentum=0.9)

        self.conv_1 = tf.keras.layers.Conv2D(filters=self.n_filters, 
                               kernel_size=3, 
                               padding='same',
                               strides=self.strides,
                               kernel_initializer='he_normal')
        self.bn_1 = tf.keras.layers.BatchNormalization(momentum=0.9)
        self.relu_1 = tf.keras.layers.ReLU()

        self.conv_2 = tf.keras.layers.Conv2D(filters=self.n_filters, 
                               kernel_size=3, 
                               padding='same', 
                               kernel_initializer='he_normal')
        self.bn_2 = tf.keras.layers.BatchNormalization(momentum=0.9)
        self.relu_2 = tf.keras.layers.ReLU()

    def call(self, x, training=False):
        shortcut = x
        if self.strides == 2:
            shortcut = self.conv_shorcut(x)
            shortcut = self.bn_shortcut(shortcut)
        y = self.conv_1(x)
        y = self.bn_1(y)
        y = self.relu_1(y)
        y = self.conv_2(y)
        y = self.bn_2(y)
        y = tf.add(shortcut, y)
        y = self.relu_2(y)
        return y

class ResNet34(tf.keras.Model):
    def __init__(self, include_top=True, n_classes=1000):
        super(ResNet34, self).__init__()

        self.n_classes = n_classes
        self.include_top = include_top
        self.conv_1 = tf.keras.layers.Conv2D(filters=64, 
                                               kernel_size=7, 
                                               padding='same', 
                                               strides=2, 
                                               kernel_initializer='he_normal')
        self.bn_1 = tf.keras.layers.BatchNormalization(momentum=0.9)
        self.relu_1 = tf.keras.layers.ReLU()
        self.maxpool = tf.keras.layers.MaxPool2D(3, 2, padding='same')
        self.residual_blocks = tf.keras.Sequential()
        for n_filters, reps, downscale in zip([64, 128, 256, 512], 
                                              [3, 4, 6, 3], 
                                              [False, True, True, True]):
            for i in range(reps):
                if i == 0 and downscale:
                    self.residual_blocks.add(ResidualBlock(block_type='conv', 
                                                              n_filters=n_filters))
                else:
                    self.residual_blocks.add(ResidualBlock(block_type='identity', 
                                                              n_filters=n_filters))
        self.GAP = tf.keras.layers.GlobalAveragePooling2D()
        self.fc = tf.keras.layers.Dense(units=self.n_classes)

    def call(self, x, training=False):
        y = self.conv_1(x)
        y = self.bn_1(y)
        y = self.relu_1(y)
        y = self.maxpool(y)
        y = self.residual_blocks(y)
        if self.include_top:
            y = self.GAP(y)
            y = self.fc(y)
        return y

## saving weights
model = ResNet34()
model.build((1, 224, 224, 3))
model.summary()
model.save_weights('model_weights.h5')

## loading saved weights
model_new = ResNet34()
model_new.build((1, 224, 224, 3))
model_new.load_weights('model_weights.h5')