我在tensorflow中训练生成对抗网络(GAN),基本上我们有两个不同的网络,每个网络都有自己的优化器。
self.G, self.layer = self.generator(self.inputCT,batch_size_tf)
self.D, self.D_logits = self.discriminator(self.GT_1hot)
...
self.g_optim = tf.train.MomentumOptimizer(self.learning_rate_tensor, 0.9).minimize(self.g_loss, global_step=self.global_step)
self.d_optim = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5) \
.minimize(self.d_loss, var_list=self.d_vars)
问题在于我首先训练其中一个网络(g),然后,我想一起训练g和d。但是,当我调用加载函数时:
self.sess.run(tf.initialize_all_variables())
self.sess.graph.finalize()
self.load(self.checkpoint_dir)
def load(self, checkpoint_dir):
print(" [*] Reading checkpoints...")
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
if ckpt and ckpt.model_checkpoint_path:
ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
self.saver.restore(self.sess, ckpt.model_checkpoint_path)
return True
else:
return False
我有这样的错误(有更多的追溯):
Tensor name "beta2_power" not found in checkpoint files checkpoint/MR2CT.model-96000
我可以恢复g网络并继续使用该功能进行训练,但是当我想从头开始创建d,并且从存储的模型中获取g时,我会遇到该错误。
答案 0 :(得分:32)
要恢复变量的子集,您必须创建一个新的tf.train.Saver
并将其传递给特定的变量列表,以便在可选的var_list
参数中进行恢复。
默认情况下,tf.train.Saver
将创建操作,以便(i)在您调用saver.save()
时保存图表中的每个变量,以及(ii)查找(按名称)给定检查点中的每个变量致电saver.restore()
。虽然这适用于大多数常见方案,但您必须提供更多信息以使用变量的特定子集:
如果您只想恢复变量的子集,可以通过调用tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=G_NETWORK_PREFIX)
来获取这些变量的列表,假设您将“g”网络放在一个公共with tf.name_scope(G_NETWORK_PREFIX):
中}或tf.variable_scope(G_NETWORK_PREFIX):
阻止。然后,您可以将此列表传递给tf.train.Saver
构造函数。
如果要恢复变量的子集和/或检查点中的变量具有不同的名称,则可以将字典作为var_list
参数传递。默认情况下,检查点中的每个变量都与键关联,后者是其tf.Variable.name
属性的值。如果目标图中的名称不同(例如,因为您添加了作用域前缀),则可以指定将字符串键(在检查点文件中)映射到tf.Variable
对象(在目标图中)的字典。
答案 1 :(得分:1)
受@mrry的启发,我提出了解决这个问题的方法。 为了说清楚,当模型建立在预先训练的模型上时,我将问题表述为从检查点恢复变量的子集。 首先,我们应该使用库undecidable中的print_tensors_in_checkpoint_file函数,或者只是通过以下方式提取此函数:
from tensorflow.python import pywrap_tensorflow
def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors):
varlist=[]
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
if all_tensors:
var_to_shape_map = reader.get_variable_to_shape_map()
for key in sorted(var_to_shape_map):
varlist.append(key)
return varlist
varlist=print_tensors_in_checkpoint_file(file_name=the path of the ckpt file,all_tensors=True,tensor_name=None)
然后我们使用tf.get_collection()就像@mrry一样:
variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
最后,我们可以通过以下方式初始化保护程序:
saver = tf.train.Saver(variable[:len(varlist)])
完整版可以在我的github上找到:inspect_checkpoint
在我的情况下,新变量被添加到模型的末尾,所以我可以简单地使用[:length()]来识别所需的变量,对于更复杂的情况,你可能需要做一些手工 - 对齐工作或编写一个简单的字符串匹配函数来确定所需的变量。
答案 2 :(得分:1)
当从检查点恢复部分变量时,我遇到了类似的问题,并且新模型中不存在某些已保存的变量。 受@Lidong回答的启发我修改了一点阅读功能:
def get_tensors_in_checkpoint_file(file_name,all_tensors=True,tensor_name=None):
varlist=[]
var_value =[]
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
if all_tensors:
var_to_shape_map = reader.get_variable_to_shape_map()
for key in sorted(var_to_shape_map):
varlist.append(key)
var_value.append(reader.get_tensor(key))
else:
varlist.append(tensor_name)
var_value.append(reader.get_tensor(tensor_name))
return (varlist, var_value)
并添加了加载功能:
def build_tensors_in_checkpoint_file(loaded_tensors):
full_var_list = list()
# Loop all loaded tensors
for i, tensor_name in enumerate(loaded_tensors[0]):
# Extract tensor
try:
tensor_aux = tf.get_default_graph().get_tensor_by_name(tensor_name+":0")
except:
print('Not found: '+tensor_name)
full_var_list.append(tensor_aux)
return full_var_list
然后您可以使用以下命令加载所有常见变量:
CHECKPOINT_NAME = path to save file
restored_vars = get_tensors_in_checkpoint_file(file_name=CHECKPOINT_NAME)
tensors_to_load = build_tensors_in_checkpoint_file(restored_vars)
loader = tf.train.Saver(tensors_to_load)
loader.restore(sess, CHECKPOINT_NAME)
编辑:我正在使用tensorflow 1.2
答案 3 :(得分:0)
您可以创建一个单独的tf.train.Saver()
实例,并将var_list
参数设置为要还原的变量。
并创建一个单独的实例来保存变量