Tensorflow:如何在训练时将模型保存在内存中

时间:2017-04-26 15:20:00

标签: python model tensorflow tensorflow-serving

我在张量流中的训练过程涉及在两个模型之间切换。 虽然使用tf.saver并从硬盘恢复模型非常耗时(在我的代码中,切换很频繁),因此,我想知道是否有办法将模型参数存储在内存中并恢复它们从记忆里。我的模型相当小,可以肯定存储在RAM中。 stackoverflow有一个答案。 Storing tensorflow models in memory但是,我不太明白这是如何工作的。有谁知道如何实现这一目标?谢谢。

1 个答案:

答案 0 :(得分:2)

您应该使用两个单独的图表:

g1 = tf.Graph()
g2 = tf.Graph()

with g1.as_default():
  # build your 1st model
  sess1 = tf.Session(graph=g1)
  # do some work with sess1 on g1
  sess1.run(...)

with g2.as_default():
  # build your 2nd model
  sess2 = tf.Session(graph=g2)
  # do some work with sess2 on g2
  sess2.run(...)

with g1.as_default():
  # do some more work with sess1 on g1 
  sess1.run(...)

with g2.as_default():
  # do some more work with sess2 on g2
  sess2.run(...)

sess1.close()
sess2.close()

您实际上并不需要with语句,一旦您创建了sess1sess2,您就可以使用它们,他们会参考正确的图形,但是当你还在习惯TF如何处理全局变量时,在你使用该图时设置默认图表可能是一种很好的形式。