我曾经使用saver.save
和saver.restore
保存和恢复检查点:
with train_graph.as_default():
saver = tf.train.Saver()
with tf.Session(graph=train_graph) as sess:
saver.restore(sess, tf.train.latest_checkpoint(embedding_checkpoint))
在某些示例中,我看到了tf.train.init_from_checkpoints()
:
tvars = tf.trainable_variables()
initialized_variable_names = {}
(assignment_map, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
它们的效果是否相等?
我检查了modeling.get_assignment_map_from_checkpoint函数,它根据名称将当前图下tf.trainable_variables()中的变量与tf.train.list_variables(init_checkpoint)
中的变量进行匹配。
我有一个问题。如果存储张量时未指定其名称,例如:
output = tf.layers.dense(a,hidden_size)
还原后,图形相同,但我添加了图层名称:
output = tf.layers.dense(a,hidden_size, name = 'dense1')
两个API是否可以恢复该层的值?看来可以指定tf.train.init_from_checkpoints()
的参数Assignment_map来保证这一点。 saver.restore()
如何做到这一点?