无论我指定了哪个变量,Tensorflow Saver都会还原所有变量

时间:2018-07-18 13:11:53

标签: python tensorflow

我正在尝试从Tensorflow图中保存和恢复变量的子集,以便丢弃我不需要的所有内容,并且它们的权重不会占用内存。将所需变量的列表或字典传递到tf.train.Saver的常见建议不起作用:无论如何,保护程序都将还原所有变量。

一个最小的工作示例:

import os
import tensorflow as tf
sess = tf.Session()
with sess.as_default():
    v1 =  tf.get_variable("v1", [5, 5, 3])
    v2 =  tf.get_variable("v2", [5, 5, 3])
    saver = tf.train.Saver([v2])
    initializer2 = tf.variables_initializer([v1, v2])
    sess.run(initializer2)
saver.save(sess, '/path/to/tf_model')

sess2 = tf.Session()
checkpoint = '/path/to/tf_model.meta'
saver.restore(sess2, tf.train.latest_checkpoint(os.path.dirname(checkpoint)))

with sess2.as_default(), sess2.graph.as_default():
    loaded_vars = tf.trainable_variables()

print(loaded_vars)

输出

[<tf.Variable 'v1:0' shape=(5, 5, 3) dtype=float32_ref>,
 <tf.Variable 'v2:0' shape=(5, 5, 3) dtype=float32_ref>]

尽管如此,print(saver._var_list)输出

[<tf.Variable 'v2:0' shape=(5, 5, 3) dtype=float32_ref>]

这是怎么了?

2 个答案:

答案 0 :(得分:1)

这就是您想要做的。请仔细检查代码。

要保存选定的变量

import tensorflow as tf

tf.reset_default_graph()

# =============================================================================
# to save
# =============================================================================

# create variables
v1 =  tf.get_variable(name="v1", initializer=[5, 5, 3])
v2 =  tf.get_variable(name="v2", initializer=[5, 5, 3])

# initialize variables
init_op = tf.global_variables_initializer()

# ops to save variable v2
saver = tf.train.Saver({"my_v2": v2})

with tf.Session() as sess:
    sess.run(init_op)
    save_path = saver.save(sess, './tf_vars/model.ckpt')
    print("Model saved in file: %s" % save_path)

'Output':
Model saved in file: ./tf_vars/model.ckpt

还原保存的变量

# =============================================================================
# to restore
# =============================================================================

# Create some variables.
v1 = tf.Variable(initial_value=[0, 0, 0], name="v1")
v2 = tf.Variable(initial_value=[0, 0, 0], name="v2")

# initialize variables
init_op = tf.global_variables_initializer()

#  ops to restore variable v2.
saver = tf.train.Saver({"my_v2": v2})

with tf.Session() as sess:
    sess.run(init_op)
    # Restore variables from disk.
    saver.restore(sess, './tf_vars/model.ckpt')
    print("v1: %s" % v1.eval())
    print("v2: %s" % v2.eval())

    print("V2 variable restored.")

'Output':
v1: [0 0 0]
v2: [5 5 3]
V2 variable restored.

答案 1 :(得分:0)

tf.trainable_variables()返回存储在图中的变量对象的列表。默认情况下,此处的变量v1和v2都将存储在图中。使用saver = tf.train.Saver([v2])时,仅保存变量v2,而没有保存v1的任何值。但是,变量v1仍然存在于您的图形中。这就是为什么您print(loaded_vars)可以看到所有变量的原因。您实际上可以使用这段代码检查变量是否具有值(已初始化)

uninitialized_vars = []
for var in tf.all_variables():
    try:
        sess.run(var)
    except tf.errors.FailedPreconditionError:
        uninitialized_vars.append(var)

print(uninitialized_vars)

希望这会有所帮助!

此外,如果您知道要初始化的变量,则不需要初始化所有变量(tf.global_variable)。

tf.reset_default_graph()
# Create some variables.
v1 = tf.get_variable("v1", [3], initializer = tf.zeros_initializer)
v2 = tf.get_variable("v2", [5], initializer = tf.zeros_initializer)

# Add ops to save and restore only `v2` using the name "v2"
saver = tf.train.Saver({"v2": v2})

# Use the saver object normally after that.
with tf.Session() as sess:
# Initialize v1 since the saver will not.
v1.initializer.run()
saver.restore(sess, "/tmp/model.ckpt")

print("v1 : %s" % v1.eval())
print("v2 : %s" % v2.eval())