TensorFlow: merge values from two checkpoints into one and restore

Date: 2019-03-25 13:57:58

Tags: tensorflow

I have two models (A and B) with the same architecture; A and B have identical variable names and model settings, e.g.

['A1/B1/C1', 'A2/B2/C2', 'A3/B3/C3']

I have checkpoint files for both A and B, and I want to merge ['A1/B1/C1', 'A2/B2/C2'] from A's checkpoint with 'A3/B3/C3' from B's checkpoint into a single checkpoint file, then restore it into model A with saver.restore().
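The intended merge can be sketched in plain Python, independent of TensorFlow: every variable comes from A's checkpoint except the ones under the chosen prefix, which come from B. The dicts and values below are hypothetical stand-ins for the two checkpoints, using the variable names from the question:

```python
# Hypothetical weight dicts standing in for the two checkpoints;
# the keys mirror the variable names in the question.
ckpt_A = {'A1/B1/C1': 1.0, 'A2/B2/C2': 2.0, 'A3/B3/C3': 3.0}
ckpt_B = {'A1/B1/C1': 10.0, 'A2/B2/C2': 20.0, 'A3/B3/C3': 30.0}

from_B = ['A3/B3/C3']  # variables to take from B's checkpoint

# Start from A's values, then overwrite the selected names with B's.
merged = dict(ckpt_A)
merged.update({name: ckpt_B[name] for name in from_B})
```

Here `merged` holds A's values for the first two variables and B's value for 'A3/B3/C3', which is exactly the checkpoint content the question asks for.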

2 Answers:

Answer 0 (score: 1)

You can achieve this with init_from_checkpoint. After defining the current model, build an assignment map:

dir = 'path_to_A_and_B_checkpoint_files'
# Names of all variables stored in the checkpoint.
vars_to_load = [i[0] for i in tf.train.list_variables(dir)]
# Map each checkpoint variable name to the matching graph variable.
assignment_map = {variable.op.name: variable
                  for variable in tf.global_variables()
                  if variable.op.name in vars_to_load}

This creates a dictionary whose keys are the variable names found in the checkpoint and whose values are the corresponding variables in the current graph:

tf.train.init_from_checkpoint(dir, assignment_map)
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  #do_usual_stuff

This call goes before the session is created and takes the place of saver.restore.
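The filtering that the assignment-map comprehension performs can be illustrated with plain strings instead of real TF variables (all names below are assumptions mirroring the question's layout):

```python
# Names that tf.train.list_variables would report from the checkpoint
# (hypothetical: suppose only two of the three variables are stored).
vars_to_load = ['A1/B1/C1', 'A2/B2/C2']

# Stand-ins for the current graph's variable op names.
graph_var_names = ['A1/B1/C1', 'A2/B2/C2', 'A3/B3/C3']

# Same shape as the comprehension above: keep only graph variables
# that also exist in the checkpoint.
assignment_map = {name: name for name in graph_var_names
                  if name in vars_to_load}
```

Variables absent from the checkpoint (here 'A3/B3/C3') are simply left out of the map and keep whatever the initializer gave them.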

Answer 1 (score: 1)

Answering my own question:

import tensorflow as tf
from tensorflow.python import pywrap_tensorflow

def load_weights(ckpt_path, prefix_list):
    """Return {tensor_name: value} for checkpoint variables whose
    name starts with one of the given prefixes."""
    vars_weights = {}
    reader = pywrap_tensorflow.NewCheckpointReader(ckpt_path)
    var_to_shape_map = reader.get_variable_to_shape_map()
    for key in sorted(var_to_shape_map):
        for _pref in prefix_list:
            if key.startswith(_pref):
                # ':0' turns a variable name into its tensor name.
                vars_weights[key + ':0'] = reader.get_tensor(key)
    return vars_weights

# Build model
...
# Init variables
sess.run(tf.global_variables_initializer())
# Restore model
saver.restore(sess, load_dir_A)

prefix = ['A3/B3/C3']
# Get weights from ckpt of B
B_weights = load_weights(load_dir_B, prefix)
# Assign weights from B to A
assign_ops = [tf.assign(tf.get_default_graph().get_tensor_by_name(_name), _value)
              for _name, _value in B_weights.items()]
sess.run(assign_ops)
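The prefix selection inside load_weights can be exercised without TensorFlow; this sketch reproduces its key-matching logic on a plain dict (the keys are the question's variable names, the values are made up):

```python
def select_by_prefix(var_to_value, prefix_list):
    """Mimic load_weights: keep entries whose key starts with any of
    the prefixes, appending ':0' to form the tensor name."""
    out = {}
    for key in sorted(var_to_value):
        for _pref in prefix_list:
            if key.startswith(_pref):
                out[key + ':0'] = var_to_value[key]
    return out

ckpt_B = {'A1/B1/C1': 1, 'A2/B2/C2': 2, 'A3/B3/C3': 3}
print(select_by_prefix(ckpt_B, ['A3/B3/C3']))  # {'A3/B3/C3:0': 3}
```

Only the 'A3/B3/C3' entry survives, so the subsequent assign ops overwrite exactly that subtree of model A with B's weights.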