我想恢复变量对象。也就是说,我希望在反序列化后有一个tensorflow.Variables
类型的对象。
我尝试使用MetaGraph。这是一个最小的例子。序列:
import tensorflow as tf
var = tf.Variable(101)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
tf.add_to_collection('var', var)
saver.save(sess, 'data/sess')
反序列化
import tensorflow as tf
with tf.Session() as sess:
saver = tf.train.import_meta_graph('data/sess.meta')
saver.restore(sess, 'data/sess')
var = tf.get_collection('var')[0]
print(var)
print(type(var))
# Output:
# Tensor("Variable:0", shape=(), dtype=int32_ref)
# <class 'tensorflow.python.framework.ops.Tensor'>
print(tf.get_collection('variables'))
# [<tensorflow.python.ops.variables.Variable object at 0x10edd1d30>]
test_var = tf.get_collection('variables')[0]
print(test_var.name)
# Variable:0
问题是tf.get_collection
返回tf.Tensor
对象,而不是
tf.Variable
。但我可以在tf.Variable
集合中看到variables
个对象。
恢复Variable
对象的正确方法是什么?
答案 0 :(得分:3)
使用您报告的代码,您正确地恢复了变量和张量。但是,我建议您使用更惯用的方法来创建和恢复变量,以便更好地管理图形元素。
首先,您应该使用tf.get_variable函数来创建和初始化变量。使用参数 name ,您应该将名称与变量相关联。这将允许您在还原步骤后检索它。
还原步骤已在您报告的代码中正确实现。如果要获取对变量的引用,则应再次使用tf.get_variable
函数,而不指定任何初始值设定项或形状。 TensorFlow范围管理器将识别您已经具有已选择名称的初始化变量,并将返回该变量。请参阅以下代码以更好地演示此过程:
import tensorflow as tf
with tf.Session() as sess:
saver = tf.train.import_meta_graph('data/sess.meta')
saver.restore(sess, 'data/sess')
# suppose that your variable is called "variable_101"
var = tf.get_variable("variable_101")
# var will represent your initialized variable