正确使用tf.reset_default_graph()来破坏图形

时间:2018-01-06 09:47:55

标签: tensorflow

查看tf.reset_default_graph()的定义,在我看来tf.reset_default_graph()在我使用tf.Session()时不会重置图中的张量值(例如权重)循环:

import tensorflow as tf

for i in range(X):
  #import a pretrained graph
  with tf.Session() as sess:
    saver = tf.train.import_meta_graph(META)
    saver.restore(sess,MODEL)
    sess.run(TENSORS_IN_PRETRAINED_MDOEL)
    print sess.run(...)
    #Do sth with weghts X times here
    sess.run(tf.assign([v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) if v.name == OLD_VALUE ][0], NEW_VALUE))
  tf.reset_default_graph()

如何在嵌套循环上完全破坏tf.graph?

P.s:我确实尝试new_graph()但没有成功。

1 个答案:

答案 0 :(得分:0)

@ de1的评论是正确的。我认为您混淆了会话和图形。

  1. 在图形中绘制tf.addtf.matmultf.nn.conv2d之类的操作时,仅将节点放在图形中。
  2. 执行tf.reset_default_graph()时,将清除默认图中的这些节点。您可以通过graph1 = tf.Graph() graph2 = tf.Graph()绘制多个图形。您可以通过执行with graph2.as_default()
  3. 选择要启动会话的图形
  4. 诸如权重和偏差之类的变量仅在会话中存在。这就是为什么您总是在会话中执行tf.global_variables_initializer()

要回答您的问题: 您可以尝试:

for i in range(X):
     graphX = tf.Graph() # every time you overwrite it
     with graphX.as_default():
          with tf.Session(graph=graphX) as sess: #make sure which graph use
               ...

希望有帮助