我已经训练了一个具有张量流的模型,并在训练期间使用批量标准化。批量标准化要求用户传递一个名为is_training
的布尔值,以设置模型是处于训练阶段还是测试阶段。
当训练模型时,is_training
被设置为常数,如下所示
is_training = tf.constant(True, dtype=tf.bool, name='is_training')
我保存了经过训练的模型,文件包括checkpoint,.meta文件,.index文件和.data。我想恢复模型并使用它进行推理。
该模型无法重新训练。因此,我想恢复现有模型,将is_training
的值设置为False
,然后将模型保存回来。
如何编辑与该节点关联的布尔值,并再次保存模型?
答案 0 :(得分:2)
您可以使用tf.train.import_meta_graph
的input_map
参数将图片张量重新映射到更新的值。
config = tf.ConfigProto(allow_soft_placement=True)
with tf.Session(config=config) as sess:
# define the new is_training tensor
is_training = tf.constant(False, dtype=tf.bool, name='is_training')
# now import the graph using the .meta file of the checkpoint
saver = tf.train.import_meta_graph(
'/path/to/model.meta', input_map={'is_training:0':is_training})
# restore all weights using the model checkpoint
saver.restore(sess, '/path/to/model')
# save updated graph and variables values
saver.save(sess, '/path/to/new-model-name')