将新的权重加载到会话中

时间:2019-02-25 17:12:40

标签: python tensorflow machine-learning

是否可以将权重加载到新会话?我想通过从检查点还原权重来获取一个模型的权重,以数组的形式获取变量,并将其提供给新的会话。这可能吗?这是我目前正在尝试的代码:

def __init__(self, ):
    self.all_values = []

def importer(self):
    new_graph = tf.Graph()

    ckpt = tf.train.get_checkpoint_state('./models/Control')
    with tf.Session(graph=new_graph) as sess1:
        # import control graph and ckpt with trained values
        saver = tf.train.import_meta_graph("./models/control-graph.meta")
        saver.restore(sess1, ckpt.model_checkpoint_path)
        self.all_values = sess1.run(tf.global_variables())
        all_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        sess1.close()

    all_var_names = []
    for i in range(len(all_vars)):
        all_var_names.append(all_vars[i].name)
    np.save("./tmp/all_var_names", all_var_names)

def loader(self):

    all_var_names = np.load("./tmp/all_var_names.npy")
    new_graph = tf.Graph()
    ckpt = tf.train.get_checkpoint_state('./models/TestRun1')

    with tf.Session(graph=new_graph) as sess:
        saver = tf.train.import_meta_graph("./models/control-graph.meta")

        all_ops = []
        all_names = []
        for i in range(len(all_var_names)):
            all_names.append(all_var_names[i].split(":")[0])
            all_ops.append(new_graph.get_operation_by_name(all_names[i]).outputs[0])

        test = zip(all_ops, self.all_values)

        d = dict(test)

        #sess.run(tf.global_variables_initializer())
        sess.run(all_ops, feed_dict=d)

        print(sess.run(tf.trainable_variables()))

        saver.save(sess, './models/Test/output.cptk')

我尝试过的所有替代方法都使我接近,但从不满足我的要求。我不能使用占位符?因为我尝试根据导入的图形和ckpt值动态地执行此操作。我只是想将张量值加载到会话中。我尝试过恢复未经训练的会话,并使用优化的权重更新变量,但没有运气,有什么建议可取吗?

0 个答案:

没有答案