查看tf.reset_default_graph()的定义,在我看来tf.reset_default_graph()
在我使用tf.Session()
时不会重置图中的张量值(例如权重)循环:
import tensorflow as tf
for i in range(X):
#import a pretrained graph
with tf.Session() as sess:
saver = tf.train.import_meta_graph(META)
saver.restore(sess,MODEL)
sess.run(TENSORS_IN_PRETRAINED_MDOEL)
print sess.run(...)
#Do sth with weghts X times here
sess.run(tf.assign([v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) if v.name == OLD_VALUE ][0], NEW_VALUE))
tf.reset_default_graph()
如何在嵌套循环上完全破坏tf.graph?
P.s:我确实尝试new_graph()
但没有成功。
答案 0 :(得分:0)
@ de1的评论是正确的。我认为您混淆了会话和图形。
tf.add
,tf.matmul
,tf.nn.conv2d
之类的操作时,仅将节点放在图形中。tf.reset_default_graph()
时,将清除默认图中的这些节点。您可以通过graph1 = tf.Graph()
graph2 = tf.Graph()
绘制多个图形。您可以通过执行with graph2.as_default()
tf.global_variables_initializer()
。要回答您的问题: 您可以尝试:
for i in range(X):
graphX = tf.Graph() # every time you overwrite it
with graphX.as_default():
with tf.Session(graph=graphX) as sess: #make sure which graph use
...
希望有帮助