tf.train.init_from_checkpoints()与Saver.restore()有什么区别?

时间:2020-07-20 00:11:50

标签: python tensorflow keras

我曾经使用saver.savesaver.restore保存和恢复检查点:

with train_graph.as_default():
    saver = tf.train.Saver()
with tf.Session(graph=train_graph) as sess:
    saver.restore(sess, tf.train.latest_checkpoint(embedding_checkpoint))

在某些示例中,我看到了tf.train.init_from_checkpoints()

tvars = tf.trainable_variables()
initialized_variable_names = {}
(assignment_map, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)

tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

它们的效果是否相等?

我检查了modeling.get_assignment_map_from_checkpoint函数,它根据名称将当前图下tf.trainable_variables()中的变量与tf.train.list_variables(init_checkpoint)中的变量进行匹配。

我有一个问题。如果存储张量时未指定其名称,例如:

output = tf.layers.dense(a,hidden_size)

还原后,图形相同,但我添加了图层名称:

output = tf.layers.dense(a,hidden_size, name = 'dense1')

两个API是否可以恢复该层的值?看来可以指定tf.train.init_from_checkpoints()的参数Assignment_map来保证这一点。 saver.restore()如何做到这一点?

0 个答案:

没有答案