我想用tensorflow训练GAN,然后将生成器和鉴别器导出为tensorflow_hub模块。
为此:
-我用tensorflow定义GAN架构
-训练并保存检查点
-使用不同的标签(例如:
)创建module_spec
(set(), {'batch_size': 8, 'model': 'gen'})
({'bs8', 'gen'}, {'batch_size': 8, 'model': 'gen'})
({'bs8', 'disc'}, {'batch_size': 8, 'model': 'disc'})
-使用我在训练期间保存的checkpoint_path在tf_hub_path处以module_spec导出
然后,我可以使用以下命令加载生成器:
hub.Module(tf_hub_path, tags={"gen", "bs8"})
但是,当我尝试使用类似的命令来加载鉴别器时:
hub.Module(tf_hub_path, tags={"disc", "bs8"})
我得到了错误:
ValueError: Tensor discriminator/linear/bias is not found in b'/tf_hub/variables/variables' checkpoint {'generator/fc_noise/kernel': [2, 48], 'generator/fc_noise/bias': [48]}
因此,我得出结论,鉴别器中存在的变量未保存在磁盘上的模块中。我检查了我想象中的不同错误源:
然后,我想知道检查点是否正确地将所有变量保存在图形中。
checkpoint_path = tf.train.latest_checkpoint(self.model_dir)
inspect_list = tf.train.list_variables(checkpoint_path)
print(inspect_list)
[('disc_step_1/beta1_power', []),
('disc_step_1/beta2_power', []),
('discriminator/linear/bias', [1]),
('discriminator/linear/bias/d_opt', [1]),
('discriminator/linear/bias/d_opt_1', [1]),
('discriminator/linear/kernel', [3, 1]),
('discriminator/linear/kernel/d_opt', [3, 1]),
('discriminator/linear/kernel/d_opt_1', [3, 1]),
('gen_step/beta1_power', []),
('gen_step/beta2_power', []),
('generator/fc_noise/bias', [48]),
('generator/fc_noise/bias/g_opt', [48]),
('generator/fc_noise/bias/g_opt_1', [48]),
('generator/fc_noise/kernel', [2, 48]),
('generator/fc_noise/kernel/g_opt', [2, 48]),
('generator/fc_noise/kernel/g_opt_1', [2, 48]),
('global_step', []),
('global_step_disc', [])]
因此,我看到所有变量都已正确保存在检查点中。在磁盘的tf hub模块中,仅正确导出了与生成器相关的两个变量。
最后,我想我的错误来自:
module_spec.export(tf_hub_path, checkpoint_path=checkpoint_path)
仅考虑标签“ gen”才能从checkpoint_path导出变量。我还检查了变量的名称是否在module.variable_map和来自检查点路径的列表变量之间相对应。这是带有标签“ disc”的模块的变量映射:
print(module.variable_map)
{'discriminator/linear/bias': <tf.Variable 'module_8/discriminator/linear/bias:0' shape=(1,) dtype=float32>, 'discriminator/linear/kernel': <tf.Variable 'module_8/discriminator/linear/kernel:0' shape=(3, 1) dtype=float32>}
我有
感谢您的帮助
答案 0 :(得分:0)
即使我认为这不是解决问题的最干净方法,我还是找到了一种解决方法:
在调用hub.Module且没有标签时,默认情况下,代码的下一行定义模块:
(set(), {'batch_size': 8, 'model': 'gen'})
实际上,我意识到这组参数定义了通过module_spec.export导出的图形。解释了为什么在导入模块时我能够访问生成器的变量,但不能识别一个。
因此,我决定默认使用这组参数:
(set(), {'batch_size': 8, 'model': 'both'})
然后,在hub.create_module_spec调用的_module_fn方法中,我将生成器和鉴别器的输入(以及输出)定义为模型的输入(分别为输出)。因此,在导出module_spec时,我可以访问该图的所有变量。