When I use pretrained model parameters, why do I still need tf.global_variables_initializer() to initialize the variables?

Time: 2017-05-03 19:56:32

Tags: tensorflow

def train_model(model, batch_gen, num_train_steps, weights_fld):
    saver = tf.train.Saver()  # defaults to saving all variables - in this case embed_matrix, nce_weight, nce_bias

    initial_step = 0
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        ckpt = tf.train.get_checkpoint_state(os.path.dirname('checkpoints/checkpoint'))
        # if that checkpoint exists, restore from checkpoint
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)

In the code above, it is clear how the TensorFlow graph tries to import pretrained parameters if they exist (the saver.restore block). So if a set of parameters has already been trained (for example, a neural network's weights), why do we still need to initialize the variables with tf.global_variables_initializer()?

1 Answer:

Answer 0 (score: 4)

Variables must be initialized before any part of the TensorFlow graph is run. Restoring them with saver.restore(sess, file) is itself a form of initialization, so tf.global_variables_initializer() is only needed on the path where no checkpoint is restored.

Rewriting your code:

with tf.Session() as sess:
    ckpt = tf.train.get_checkpoint_state(os.path.dirname('checkpoints/checkpoint'))
    # if that checkpoint exists, restore from checkpoint
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
    else:
        sess.run(tf.global_variables_initializer())
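The initialize-or-restore pattern above can be sketched end to end with a tiny graph. This is a minimal sketch, not the question's model: the variable `v`, its values, and the temporary checkpoint directory are made up for illustration, and it assumes TensorFlow is installed and uses the TF 1.x-style API via `tf.compat.v1`:

```python
import os
import tempfile

import tensorflow.compat.v1 as tf  # TF 1.x-style graph/session API
tf.disable_eager_execution()       # graph mode, as in the question

ckpt_dir = tempfile.mkdtemp()      # hypothetical checkpoint directory

# First run: no checkpoint exists yet, so the else branch runs the
# initializer to give v its starting value.
with tf.Graph().as_default():
    v = tf.get_variable("v", initializer=tf.constant(42.0))
    saver = tf.train.Saver()
    with tf.Session() as sess:
        ckpt = tf.train.get_checkpoint_state(ckpt_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            sess.run(tf.global_variables_initializer())
        sess.run(v.assign(7.0))  # stand-in for "training" the variable
        saver.save(sess, os.path.join(ckpt_dir, "model"))

# Second run: the checkpoint now exists, so saver.restore() initializes v
# directly from disk and global_variables_initializer() is never called.
with tf.Graph().as_default():
    v = tf.get_variable("v", initializer=tf.constant(42.0))
    saver = tf.train.Saver()
    with tf.Session() as sess:
        ckpt = tf.train.get_checkpoint_state(ckpt_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            sess.run(tf.global_variables_initializer())
        restored_value = sess.run(v)
        print(restored_value)  # the restored value, not the initializer's 42.0
```

Note the design point: both branches leave every variable initialized, which is exactly why the original code's unconditional `sess.run(tf.global_variables_initializer())` is redundant (and overwrites nothing only by luck of ordering) when a restore follows it.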

You can see another example here for a complete working example.