在Tensorflow中恢复变量子集

时间:2017-01-12 19:11:01

标签: python tensorflow

我在tensorflow中训练生成对抗网络(GAN),基本上我们有两个不同的网络,每个网络都有自己的优化器。

self.G, self.layer = self.generator(self.inputCT,batch_size_tf)
self.D, self.D_logits = self.discriminator(self.GT_1hot)

...

self.g_optim = tf.train.MomentumOptimizer(self.learning_rate_tensor, 0.9).minimize(self.g_loss, global_step=self.global_step)

self.d_optim = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5) \
                      .minimize(self.d_loss, var_list=self.d_vars)

问题在于我首先训练其中一个网络(g),然后,我想一起训练g和d。但是,当我调用加载函数时:

self.sess.run(tf.initialize_all_variables())
self.sess.graph.finalize()

self.load(self.checkpoint_dir)

def load(self, checkpoint_dir):
    print(" [*] Reading checkpoints...")

    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if ckpt and ckpt.model_checkpoint_path:
        ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
        self.saver.restore(self.sess, ckpt.model_checkpoint_path)
        return True
    else:
        return False

我有这样的错误(有更多的追溯):

Tensor name "beta2_power" not found in checkpoint files checkpoint/MR2CT.model-96000

我可以恢复g网络并继续使用该功能进行训练,但是当我想从头开始创建d,并且从存储的模型中获取g时,我会遇到该错误。

4 个答案:

答案 0 :(得分:32)

要恢复变量的子集,您必须创建一个新的tf.train.Saver并将其传递给特定的变量列表,以便在可选的var_list参数中进行恢复。

默认情况下,tf.train.Saver将创建操作,以便(i)在您调用saver.save()时保存图表中的每个变量,以及(ii)查找(按名称)给定检查点中的每个变量致电saver.restore()。虽然这适用于大多数常见方案,但您必须提供更多信息以使用变量的特定子集:

  1. 如果您只想恢复变量的子集,可以通过调用tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=G_NETWORK_PREFIX)来获取这些变量的列表,假设您将“g”网络放在一个公共with tf.name_scope(G_NETWORK_PREFIX):中}或tf.variable_scope(G_NETWORK_PREFIX):阻止。然后,您可以将此列表传递给tf.train.Saver构造函数。

  2. 如果要恢复变量的子集和/或检查点中的变量具有不同的名称,则可以将字典作为var_list参数传递。默认情况下,检查点中的每个变量都与关联,后者是其tf.Variable.name属性的值。如果目标图中的名称不同(例如,因为您添加了作用域前缀),则可以指定将字符串键(在检查点文件中)映射到tf.Variable对象(在目标图中)的字典。

答案 1 :(得分:1)

受@mrry的启发,我提出了解决这个问题的方法。 为了说清楚,当模型建立在预先训练的模型上时,我将问题表述为从检查点恢复变量的子集。 首先,我们应该使用库undecidable中的print_tensors_in_checkpoint_file函数,或者只是通过以下方式提取此函数:

from tensorflow.python import pywrap_tensorflow
def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors):
    varlist=[]
    reader = pywrap_tensorflow.NewCheckpointReader(file_name)
    if all_tensors:
      var_to_shape_map = reader.get_variable_to_shape_map()
      for key in sorted(var_to_shape_map):
        varlist.append(key)
    return varlist
varlist=print_tensors_in_checkpoint_file(file_name=the path of the ckpt file,all_tensors=True,tensor_name=None)

然后我们使用tf.get_collection()就像@mrry一样:

variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)

最后,我们可以通过以下方式初始化保护程序:

saver = tf.train.Saver(variable[:len(varlist)])

完整版可以在我的github上找到:inspect_checkpoint

在我的情况下,新变量被添加到模型的末尾,所以我可以简单地使用[:length()]来识别所需的变量,对于更复杂的情况,你可能需要做一些手工 - 对齐工作或编写一个简单的字符串匹配函数来确定所需的变量。

答案 2 :(得分:1)

当从检查点恢复部分变量时,我遇到了类似的问题,并且新模型中不存在某些已保存的变量。 受@Lidong回答的启发我修改了一点阅读功能:

def get_tensors_in_checkpoint_file(file_name,all_tensors=True,tensor_name=None):
varlist=[]
var_value =[]
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
if all_tensors:
  var_to_shape_map = reader.get_variable_to_shape_map()
  for key in sorted(var_to_shape_map):
    varlist.append(key)
    var_value.append(reader.get_tensor(key))
else:
    varlist.append(tensor_name)
    var_value.append(reader.get_tensor(tensor_name))
return (varlist, var_value)

并添加了加载功能:

def build_tensors_in_checkpoint_file(loaded_tensors):
full_var_list = list()
# Loop all loaded tensors
for i, tensor_name in enumerate(loaded_tensors[0]):
    # Extract tensor
    try:
        tensor_aux = tf.get_default_graph().get_tensor_by_name(tensor_name+":0")
    except:
        print('Not found: '+tensor_name)
    full_var_list.append(tensor_aux)
return full_var_list

然后您可以使用以下命令加载所有常见变量:

CHECKPOINT_NAME = path to save file
restored_vars  = get_tensors_in_checkpoint_file(file_name=CHECKPOINT_NAME)
tensors_to_load = build_tensors_in_checkpoint_file(restored_vars)
loader = tf.train.Saver(tensors_to_load)
loader.restore(sess, CHECKPOINT_NAME)

编辑:我正在使用tensorflow 1.2

答案 3 :(得分:0)

您可以创建一个单独的tf.train.Saver()实例,并将var_list参数设置为要还原的变量。 并创建一个单独的实例来保存变量