I'm trying to save and restore a subset of the variables in a TensorFlow graph, so that everything I don't need can be discarded and its weights don't take up memory. The common advice of passing a list or dict of the desired variables to tf.train.Saver does not work: the saver restores all the variables anyway.
A minimal working example:
import os
import tensorflow as tf
sess = tf.Session()
with sess.as_default():
    v1 = tf.get_variable("v1", [5, 5, 3])
    v2 = tf.get_variable("v2", [5, 5, 3])
    saver = tf.train.Saver([v2])
    initializer2 = tf.variables_initializer([v1, v2])
    sess.run(initializer2)
    saver.save(sess, '/path/to/tf_model')
sess2 = tf.Session()
checkpoint = '/path/to/tf_model.meta'
saver.restore(sess2, tf.train.latest_checkpoint(os.path.dirname(checkpoint)))
with sess2.as_default(), sess2.graph.as_default():
    loaded_vars = tf.trainable_variables()
    print(loaded_vars)
Output:
[<tf.Variable 'v1:0' shape=(5, 5, 3) dtype=float32_ref>,
<tf.Variable 'v2:0' shape=(5, 5, 3) dtype=float32_ref>]
Nevertheless, print(saver._var_list) outputs
[<tf.Variable 'v2:0' shape=(5, 5, 3) dtype=float32_ref>]
What is going wrong here?
Answer 0: (score: 1)
This does what you want. Go through the code carefully.
import tensorflow as tf
tf.reset_default_graph()
# =============================================================================
# to save
# =============================================================================
# create variables
v1 = tf.get_variable(name="v1", initializer=[5, 5, 3])
v2 = tf.get_variable(name="v2", initializer=[5, 5, 3])
# initialize variables
init_op = tf.global_variables_initializer()
# ops to save variable v2
saver = tf.train.Saver({"my_v2": v2})
with tf.Session() as sess:
    sess.run(init_op)
    save_path = saver.save(sess, './tf_vars/model.ckpt')
    print("Model saved in file: %s" % save_path)
'Output':
Model saved in file: ./tf_vars/model.ckpt
# =============================================================================
# to restore
# =============================================================================
# Create some variables.
v1 = tf.Variable(initial_value=[0, 0, 0], name="v1")
v2 = tf.Variable(initial_value=[0, 0, 0], name="v2")
# initialize variables
init_op = tf.global_variables_initializer()
# ops to restore variable v2.
saver = tf.train.Saver({"my_v2": v2})
with tf.Session() as sess:
    sess.run(init_op)
    # Restore variables from disk.
    saver.restore(sess, './tf_vars/model.ckpt')
    print("v1: %s" % v1.eval())
    print("v2: %s" % v2.eval())
    print("V2 variable restored.")
'Output':
v1: [0 0 0]
v2: [5 5 3]
V2 variable restored.
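To double-check what actually landed in the checkpoint, you can list its contents. A minimal sketch using tf.train.list_variables (available in recent TF 1.x releases; the checkpoint path is the one saved above):
import tensorflow as tf
# Print every (name, shape) pair stored in the checkpoint.
# Only "my_v2" should appear, because the Saver was built with {"my_v2": v2}.
for name, shape in tf.train.list_variables('./tf_vars/model.ckpt'):
    print(name, shape)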
Answer 1: (score: 0)
tf.trainable_variables() returns the list of variable objects stored in the graph. By default, both v1 and v2 here are stored in the graph. When you use saver = tf.train.Saver([v2]), only variable v2 is saved; no value is saved for v1. However, variable v1 still exists in your graph, which is why print(loaded_vars) shows all the variables. You can check whether a variable actually holds a value (i.e. has been initialized) with this snippet:
uninitialized_vars = []
for var in tf.global_variables():  # tf.all_variables() is deprecated
    try:
        sess.run(var)
    except tf.errors.FailedPreconditionError:
        uninitialized_vars.append(var)
print(uninitialized_vars)
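A lighter-weight alternative (a sketch, assuming the same session sess as above) is the built-in op that reports uninitialized variables by name, without triggering a FailedPreconditionError per variable:
# Returns a 1-D string tensor with the names of every variable
# that has not been initialized or restored yet.
print(sess.run(tf.report_uninitialized_variables()))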
Hope this helps!
Also, if you know which variables you want to initialize, you don't need to initialize all of them with tf.global_variables_initializer().
tf.reset_default_graph()
# Create some variables.
v1 = tf.get_variable("v1", [3], initializer = tf.zeros_initializer)
v2 = tf.get_variable("v2", [5], initializer = tf.zeros_initializer)
# Add ops to save and restore only `v2` using the name "v2"
saver = tf.train.Saver({"v2": v2})
# Use the saver object normally after that.
with tf.Session() as sess:
    # Initialize v1 since the saver will not.
    v1.initializer.run()
    saver.restore(sess, "/tmp/model.ckpt")
    print("v1 : %s" % v1.eval())
    print("v2 : %s" % v2.eval())