函数tf.contrib.framework.init_from_checkpoint无法正常工作

时间:2018-01-04 14:18:23

标签: tensorflow

我遇到了函数tf.contrib.framework.init_from_checkpoint的问题。它根本不起作用(很可能我做错了)。我精心设计了以下示例来演示行为:

import tensorflow as tf
model_name = "./my_model.ckp"

### MY MODEL IS COMPOSED BY 2 VARIABLES
with tf.variable_scope("A"):
    A = tf.Variable([1, 2, 3], name="A1")

with tf.variable_scope("B"):
    B = tf.Variable([4, 5, 6], name="B1")

# INITIALIZING AND SAVING THE MODEL    
with tf.Session() as sess:
    tf.global_variables_initializer().run(session=sess)
    print(sess.run([A, B]))

    saver = tf.train.Saver()
    saver.save(sess, model_name)

#### CLEANING UP
tf.reset_default_graph()


### CREATING OTHER "MODEL"
with tf.variable_scope("C"):
    A = tf.Variable([0, 0, 0], name="A1")

with tf.variable_scope("B"):
    B = tf.Variable([0, 0, 0], name="B1")

# MAPPING THE VARIABLES FROM MY CHECKPOINT TO MY NEW SET OF VARIABLES
tf.contrib.framework.init_from_checkpoint(
    model_name,
    {"A/": "C/", 
    "B/": "B/"})

with tf.Session() as sess:
    tf.global_variables_initializer().run(session=sess)
    print(sess.run([A, B]))

输出是: [array([1,2,3],dtype = int32),array([4,5,6],dtype = int32)] - >这是预期和 [array([0,0,0,dtype = int32),array([0,0,0,dtype = int32)],这是不期望的。

发生了什么?

由于

1 个答案:

答案 0 :(得分:1)

问题是您使用低级方法Variable来创建变量,因此它不会存储在变量存储中。

### CREATING OTHER "MODEL"中,如果进行了以下更改:

with tf.variable_scope("C"): A = tf.get_variable(name='A1', initializer=[0,0,0])

然后我测试了它可以从检查点成功恢复。