如何恢复变量对象

时间:2017-03-11 10:33:45

标签: tensorflow

我想恢复变量对象。也就是说,我希望在反序列化后有一个tensorflow.Variables类型的对象。

我尝试使用MetaGraph。这是一个最小的例子。序列:

import tensorflow as tf

var = tf.Variable(101)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    tf.add_to_collection('var', var)
    saver.save(sess, 'data/sess')

反序列化

import tensorflow as tf

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('data/sess.meta')
    saver.restore(sess, 'data/sess')

    var = tf.get_collection('var')[0]
    print(var)
    print(type(var))
    # Output:
    # Tensor("Variable:0", shape=(), dtype=int32_ref)
    # <class 'tensorflow.python.framework.ops.Tensor'>

    print(tf.get_collection('variables'))
    # [<tensorflow.python.ops.variables.Variable object at 0x10edd1d30>]

    test_var = tf.get_collection('variables')[0]
    print(test_var.name)
    # Variable:0

问题是tf.get_collection返回tf.Tensor对象,而不是 tf.Variable。但我可以在tf.Variable集合中看到variables个对象。

恢复Variable对象的正确方法是什么?

1 个答案:

答案 0 :(得分:3)

使用您报告的代码,您正确地恢复了变量和张量。但是,我建议您使用更惯用的方法来创建和恢复变量,以便更好地管理图形元素。

首先,您应该使用tf.get_variable函数来创建和初始化变量。使用参数 name ,您应该将名称与变量相关联。这将允许您在还原步骤后检索它。

还原步骤已在您报告的代码中正确实现。如果要获取对变量的引用,则应再次使用tf.get_variable函数,而不指定任何初始值设定项或形状。 TensorFlow范围管理器将识别您已经具有已选择名称的初始化变量,并将返回该变量。请参阅以下代码以更好地演示此过程:

import tensorflow as tf

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('data/sess.meta')
    saver.restore(sess, 'data/sess')

    # suppose that your variable is called "variable_101"
    var = tf.get_variable("variable_101")

    # var will represent your initialized variable