恢复训练的张量流模型,编辑与节点关联的值并保存

时间:2017-08-17 11:42:02

标签: python machine-learning tensorflow deep-learning batch-normalization

我已经训练了一个具有张量流的模型,并在训练期间使用批量标准化。批量标准化要求用户传递一个名为is_training的布尔值,以设置模型是处于训练阶段还是测试阶段。

当训练模型时,is_training被设置为常数,如下所示

is_training = tf.constant(True, dtype=tf.bool, name='is_training')

我保存了经过训练的模型,文件包括checkpoint,.meta文件,.index文件和.data。我想恢复模型并使用它进行推理。 该模型无法重新训练。因此,我想恢复现有模型,将is_training的值设置为False,然后将模型保存回来。 如何编辑与该节点关联的布尔值,并再次保存模型?

1 个答案:

答案 0 :(得分:2)

您可以使用tf.train.import_meta_graphinput_map参数将图片张量重新映射到更新的值。

config = tf.ConfigProto(allow_soft_placement=True)
with tf.Session(config=config) as sess:
    # define the new is_training tensor
    is_training = tf.constant(False, dtype=tf.bool, name='is_training')

    # now import the graph using the .meta file of the checkpoint
    saver = tf.train.import_meta_graph(
    '/path/to/model.meta', input_map={'is_training:0':is_training})

    # restore all weights using the model checkpoint 
    saver.restore(sess, '/path/to/model')

    # save updated graph and variables values
    saver.save(sess, '/path/to/new-model-name')