修改导入的Tensorflow图表

时间:2016-12-14 14:32:11

标签: python tensorflow

我创建了一个带有AdamOptimizer的图表,然后我用tf.train.Saver().save(session, "model_name")保存了

经过一段时间的训练后,我能够在不同的会话中导入整个图表和变量并继续训练

saver = tf.train.import_meta_graph("model_name")
saver.restore(session, "model_name")

我想要做的是,在导入图形+变量之后和恢复优化之前,更改AdamOptimizer的learning_rate。那可能吗?

编辑:这样做的一种方法是将学习率定义为占位符,并且每次都提供不同的值。但是我们假设图形已经被保存而没有为了论证而这样做。

1 个答案:

答案 0 :(得分:1)

我认为你可以用placeholder替换learning_rate,即

learning_rate = tf.placeholder(tf.float32,shape=(),name="learing_rate")
train_op = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(your_loss_tensor, name="train_op")

恢复图表后,使用

获取所有与train_oplearning_rate相关的所有操作和张量
train_op = graph.get_operation_by_name("train_op")
learning_rate = graph.get_tensor_by_name("learning_rate:0")

并运行火车

sess.run(train_op, feed_dict={learning_rate: whatever_you_what})

更新: 如果您想更改已保存图表的某些输入,请参阅this