Saving weights when the model is not Sequential

Time: 2020-03-26 00:30:36

Tags: tensorflow machine-learning keras

My generator model is defined as


import tensorflow as tf


class Generator(tf.keras.Model):
    def __init__(self, channels=3, method='transpose'):
        super(Generator, self).__init__()
        self.channels = channels
        self.method = method

        # Project the latent vector up to an 8x8x512 feature map
        self.dense = tf.keras.layers.Dense(512 * 8 * 8, use_bias=False)
        self.reshape = tf.keras.layers.Reshape((8, 8, 512))

        # Transposed convolutions upsample 8x8 -> 128x128
        self.convT_1 = tf.keras.layers.Conv2DTranspose(256, (3, 3), padding='same', use_bias=False)
        self.convT_2 = tf.keras.layers.Conv2DTranspose(256, (3, 3), strides=(2, 2), padding='same', use_bias=False)
        self.convT_3 = tf.keras.layers.Conv2DTranspose(128, (3, 3), strides=(2, 2), padding='same', use_bias=False)
        self.convT_4 = tf.keras.layers.Conv2DTranspose(128, (3, 3), strides=(2, 2), padding='same', use_bias=False)
        self.convT_5 = tf.keras.layers.Conv2DTranspose(self.channels, (4, 4), strides=(2, 2), padding='same', use_bias=False, activation='tanh')

        self.batch_norm_1 = tf.keras.layers.BatchNormalization()
        self.batch_norm_2 = tf.keras.layers.BatchNormalization()
        self.batch_norm_3 = tf.keras.layers.BatchNormalization()
        self.batch_norm_4 = tf.keras.layers.BatchNormalization()
        self.batch_norm_5 = tf.keras.layers.BatchNormalization()

        self.leakyrelu_1 = tf.keras.layers.LeakyReLU()
        self.leakyrelu_2 = tf.keras.layers.LeakyReLU()
        self.leakyrelu_3 = tf.keras.layers.LeakyReLU()
        self.leakyrelu_4 = tf.keras.layers.LeakyReLU()
        self.leakyrelu_5 = tf.keras.layers.LeakyReLU()

    def call(self, inputs, training=True):
        x = self.dense(inputs)
        x = self.batch_norm_1(x, training=training)
        x = self.leakyrelu_1(x)
        x = self.reshape(x)

        x = self.convT_1(x)
        x = self.batch_norm_2(x, training=training)
        x = self.leakyrelu_2(x)

        x = self.convT_2(x)
        x = self.batch_norm_3(x, training=training)
        x = self.leakyrelu_3(x)

        x = self.convT_3(x)
        x = self.batch_norm_4(x, training=training)
        x = self.leakyrelu_4(x)

        x = self.convT_4(x)
        x = self.batch_norm_5(x, training=training)
        x = self.leakyrelu_5(x)

        return self.convT_5(x)
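For context, calling the generator on a latent batch runs one forward pass and, as a side effect, builds all of its variables. A quick usage sketch (the batch size and the latent dimension of 100 are my assumptions; the class does not fix them):

import numpy as np

generator = Generator()
z = np.random.normal(size=(1, 100)).astype("float32")  # assumed latent size
img = generator(z, training=False)
print(img.shape)  # (1, 128, 128, 3): 8x8 doubled by each of the four strided layers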

Then I save the weights as

generator = Generator()
path = "abc.h5"
generator.save_weights(path)
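Note that a subclassed model owns no layer weights until its first forward pass, so saving straight after construction may write a file with no usable layer weights in it. A minimal sketch of building before saving (again assuming a 100-dim latent input):

import numpy as np

generator = Generator()
_ = generator(np.zeros((1, 100), dtype=np.float32), training=False)  # build variables
generator.save_weights("abc.h5")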

Then I try to load the weights as follows

generator = Generator()
generator.load_weights("abc.h5")

but this fails with an error message:

ValueError: You are trying to load a weight file containing X layers into a model with 0 layers.

I have looked into solutions, but every answer targets models defined with model.Sequential().
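For a subclassed model, a common workaround (a sketch under the same assumed latent size of 100, not an accepted answer from this thread) is either to build the fresh model with a dummy forward pass before load_weights, or to use the TensorFlow checkpoint format, which defers restoring until the variables exist:

import numpy as np

# Option 1: build the variables first, then load the HDF5 file.
generator = Generator()
_ = generator(np.zeros((1, 100), dtype=np.float32), training=False)
generator.load_weights("abc.h5")

# Option 2: save/load in the TF checkpoint format (no .h5 extension);
# restoration is deferred until the model is built, so no dummy call is needed.
generator.save_weights("abc_ckpt")
restored = Generator()
restored.load_weights("abc_ckpt")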

0 Answers:

No answers yet.