从检查点加载模型失败?

时间:2019-07-09 08:33:10

标签: tensorflow2.0 generative-adversarial-network

https://www.tensorflow.org/beta/tutorials/generative/dcgan 我正在按照DCGAN的本教程进行操作,并尝试从保存的检查点还原模型。但是当我尝试加载模型时,它给了我一个错误:

AssertionError: Nothing except the root object matched a checkpointed value. Typically this means that the checkpoint does not match the Python program. The following objects have no matching checkpointed value: [<tensorflow.python.keras.layers.advanced_activations.LeakyReLU object at 0x7f05b545fc88>, <tf.Variable 'conv2d_transpose_5/kernel:0' shape=(3, 3, 1, 64) dtype=float32, numpy=.......

我试图仅保留生成器部分来修改检查点,因为这是我从这里所需的全部。之后,我做

latest = tf.train.latest_checkpoint(checkpoint_dir)
gen_mod = make_generator_model() #This is already defined in the code
gen_mod.load_weights(latest)
sample = gen_mod(noise,training=False)

这给了我错误。有没有一种方法可以只加载发电机零件? 我想要的是能够使用给定检查点的生成器模型生成图像。

1 个答案:

答案 0 :(得分:0)

正如您所说,DCGAN教程创建了两个模型的快照(生成器和鉴别器),load_weights必须选择仅与生成器相关的权重,并且缺少执行此操作的上下文

您可以restore检查点,然后从那里访问生成器(或鉴别器),而不是尝试将权重直接加载到模型中:

# from the DCGAN tutorial
checkpoint = tf.train.Checkpoint(
    generator_optimizer=generator_optimizer,
    discriminator_optimizer=discriminator_optimizer,
    generator=generator,
    discriminator=discriminator,
)

latest = tf.train.latest_checkpoint(checkpoint_dir)
checkpoint.restore(latest)

# classify an image
checkpoint.discriminator(training_images[0:2])

# generate an image
checkpoint.generator(noise)