def train_model(model, batch_gen, num_train_steps, weights_fld):
    saver = tf.train.Saver()  # defaults to saving all variables - in this case embed_matrix, nce_weight, nce_bias
    initial_step = 0
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())            # <-- highlighted
        ckpt = tf.train.get_checkpoint_state(os.path.dirname('checkpoints/checkpoint'))
        # if that checkpoint exists, restore from checkpoint
        if ckpt and ckpt.model_checkpoint_path:                # <-- highlighted
            saver.restore(sess, ckpt.model_checkpoint_path)    # <-- highlighted
In the code above, it is clear how the tensorflow graph tries to import the pre-trained parameters, if they exist (the highlighted lines). So if a trained set of parameters (e.g., the weights of the neural network) already exists, why do we still need to initialize the variables with tf.global_variables_initializer()?
Answer 0 (score: 4)
If you call saver.restore(sess, file) before running any part of the tensorflow graph, you do not have to use tf.global_variables_initializer(), because restoring the checkpoint assigns a value to every variable the saver covers.
Rewriting your code:
with tf.Session() as sess:
    ckpt = tf.train.get_checkpoint_state(os.path.dirname('checkpoints/checkpoint'))
    # if that checkpoint exists, restore from checkpoint
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
    else:
        sess.run(tf.global_variables_initializer())
You can see another complete working example here.
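The restore-or-initialize branching above can be illustrated without TensorFlow at all. The following is a minimal sketch of the same idea using a hypothetical pickle-based checkpoint (the function and variable names are illustrative, not TensorFlow API): if a checkpoint file exists, its saved values take the place of initialization; otherwise the parameters are freshly initialized.

```python
import os
import pickle
import tempfile

def load_or_init_params(ckpt_path):
    """Restore parameters from a checkpoint if one exists, otherwise
    fall back to fresh initialization -- mirroring the
    saver.restore / tf.global_variables_initializer() branch."""
    if os.path.exists(ckpt_path):
        with open(ckpt_path, "rb") as f:
            return pickle.load(f)  # plays the role of saver.restore
    # plays the role of the initializer
    return {"embed_matrix": 0.0, "nce_weight": 0.0, "nce_bias": 0.0}

# demo: first call initializes, second call restores
ckpt = os.path.join(tempfile.mkdtemp(), "model.ckpt")
fresh = load_or_init_params(ckpt)  # no checkpoint yet -> initialized
trained = {"embed_matrix": 1.5, "nce_weight": 0.3, "nce_bias": -0.1}
with open(ckpt, "wb") as f:
    pickle.dump(trained, f)        # plays the role of saver.save
restored = load_or_init_params(ckpt)  # checkpoint exists -> restored
print(fresh["embed_matrix"], restored["embed_matrix"])  # 0.0 1.5
```

The key point carries over directly: once the restore branch runs, every saved parameter already has a value, so running the initializer afterwards would only overwrite the trained weights.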