我正在尝试Tensorflow示例:
import tensorflow as tf
x = tf.Variable(0, name='x')
model = tf.global_variables_initializer()
with tf.Session() as session:
for i in range(5):
session.run(model)
x = x + 1
print(session.run(x))
但输出是出乎意料的。我希望它输出:
1 1 1 1 1
BUt实际输出是:
1 2 3 4 5
这是为什么? session.run(model)每次都会初始化变量,这个语句是否正确?
答案 0 :(得分:2)
session.run(model)每次都会初始化变量
这是正确的。问题是每次x = x + 1
在图表中创建一个新添加项,这将解释您获得的结果。
第一次迭代后的图表:
第二次迭代后:
第三次迭代后:
第四次迭代后:
第五次迭代后:
我使用的代码,主要取自Yaroslav Bulatov在How can I list all Tensorflow variables a node depends on?中的回答:
import tensorflow as tf
import matplotlib.pyplot as plt
import networkx as nx
def children(op):
return set(op for out in op.outputs for op in out.consumers())
def get_graph():
"""Creates dictionary {node: {child1, child2, ..},..} for current
TensorFlow graph. Result is compatible with networkx/toposort"""
ops = tf.get_default_graph().get_operations()
return {op: children(op) for op in ops}
def plot_graph(G):
'''Plot a DAG using NetworkX'''
def mapping(node):
return node.name
G = nx.DiGraph(G)
nx.relabel_nodes(G, mapping, copy=False)
nx.draw(G, cmap = plt.get_cmap('jet'), with_labels = True)
plt.show()
x = tf.Variable(0, name='x')
model = tf.global_variables_initializer()
with tf.Session() as session:
for i in range(5):
session.run(model)
x = x + 1
print(session.run(x))
plot_graph(get_graph())
答案 1 :(得分:0)
看来你必须在循环中初始化它
import tensorflow as tf
for i in range(5):
X = tf.Variable(0, name='x')
model = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(model)
X += 1
print(X.eval())
答案 2 :(得分:0)
session.run(model)
每次都会初始化变量(因为它调用model = tf.global_variables_initializer()
),
但是对于循环中的每个条目,用于初始化的x
的值都会增加1
。例如,
i=0
的,x
初始化为其在那个实例中拥有的值,即0
。当i=1
时,x
已增加到1
,这是初始化程序将使用的值,依此类推。