我遇到了函数tf.contrib.framework.init_from_checkpoint
的问题。它根本不起作用(很可能我做错了)。我精心设计了以下示例来演示行为:
import tensorflow as tf
model_name = "./my_model.ckp"
### MY MODEL IS COMPOSED BY 2 VARIABLES
with tf.variable_scope("A"):
A = tf.Variable([1, 2, 3], name="A1")
with tf.variable_scope("B"):
B = tf.Variable([4, 5, 6], name="B1")
# INITIALIZING AND SAVING THE MODEL
with tf.Session() as sess:
tf.global_variables_initializer().run(session=sess)
print(sess.run([A, B]))
saver = tf.train.Saver()
saver.save(sess, model_name)
#### CLEANING UP
tf.reset_default_graph()
### CREATING OTHER "MODEL"
with tf.variable_scope("C"):
A = tf.Variable([0, 0, 0], name="A1")
with tf.variable_scope("B"):
B = tf.Variable([0, 0, 0], name="B1")
# MAPPING THE VARIABLES FROM MY CHECKPOINT TO MY NEW SET OF VARIABLES
tf.contrib.framework.init_from_checkpoint(
model_name,
{"A/": "C/",
"B/": "B/"})
with tf.Session() as sess:
tf.global_variables_initializer().run(session=sess)
print(sess.run([A, B]))
输出是: [array([1,2,3],dtype = int32),array([4,5,6],dtype = int32)] - >这是预期和 [array([0,0,0,dtype = int32),array([0,0,0,dtype = int32)],这是不期望的。
发生了什么?
由于
答案 0 :(得分:1)
问题是您使用低级方法Variable
来创建变量,因此它不会存储在变量存储中。
在### CREATING OTHER "MODEL"
中,如果进行了以下更改:
with tf.variable_scope("C"):
A = tf.get_variable(name='A1', initializer=[0,0,0])
然后我测试了它可以从检查点成功恢复。