我的模型生成器定义为
class Generator(tf.keras.Model):
def __init__(self, channels=3, method='transpose'):
super(Generator, self).__init__()
self.channels = channels
self.method = method
self.dense = tf.keras.layers.Dense(512 * 8 * 8, use_bias=False)
self.reshape = tf.keras.layers.Reshape((8, 8, 512))
self.convT_1 = tf.keras.layers.Conv2DTranspose(256, (3, 3), padding='same', use_bias=False)
self.convT_2 = tf.keras.layers.Conv2DTranspose(256, (3, 3), strides=(2, 2), padding='same', use_bias=False)
self.convT_3 = tf.keras.layers.Conv2DTranspose(128, (3, 3), strides=(2, 2), padding='same', use_bias=False)
self.convT_4 = tf.keras.layers.Conv2DTranspose(128,(3, 3), strides=(2, 2), padding='same', use_bias=False)
self.convT_5 = tf.keras.layers.Conv2DTranspose(self.channels, (4, 4), strides=(2, 2), padding='same', use_bias=False, activation='tanh')
self.batch_norm_1 = tf.keras.layers.BatchNormalization()
self.batch_norm_2 = tf.keras.layers.BatchNormalization()
self.batch_norm_3 = tf.keras.layers.BatchNormalization()
self.batch_norm_4 = tf.keras.layers.BatchNormalization()
self.batch_norm_5 = tf.keras.layers.BatchNormalization()
self.leakyrelu_1 = tf.keras.layers.LeakyReLU()
self.leakyrelu_2 = tf.keras.layers.LeakyReLU()
self.leakyrelu_3 = tf.keras.layers.LeakyReLU()
self.leakyrelu_4 = tf.keras.layers.LeakyReLU()
self.leakyrelu_5 = tf.keras.layers.LeakyReLU()
def call(self, inputs, training=True):
x = self.dense(inputs)
x = self.batch_norm_1(x, training)
x = self.leakyrelu_1(x)
x = self.reshape(x)
x = self.convT_1(x)
x = self.batch_norm_2(x, training)
x = self.leakyrelu_2(x)
x = self.convT_2(x)
x = self.batch_norm_3(x, training)
x = self.leakyrelu_3(x)
x = self.convT_3(x)
x = self.batch_norm_4(x, training)
x = self.leakyrelu_4(x)
x = self.convT_4(x)
x = self.batch_norm_5(x, training)
x = self.leakyrelu_5(x)
return self.convT_5(x)
然后我将权重另存为
generator=Generator()
path="abc.h5"
generator.save_weights(path)
然后我尝试按以下方式加载权重
generator=Generator()
generator=generator.load_weights("abc.h5")
但是我面对并收到一条错误消息
值错误:您正在尝试将包含X层的权重文件加载到0层的模型中
我已经研究了解决方案,但是每个答案都针对使用model.Sequential()定义的模型