为什么变量不在此程序中更新?

时间:2017-03-03 17:59:27

标签: python tensorflow

所以,我有这段代码:

graph = tf.Graph()

with graph.as_default():
    res  = tf.Variable(
        np.zeros((100, 100), dtype=np.float32)
    )

    mask = tf.placeholder(np.float32, (100, 100))
    res = tf.add(res, mask)

    init = tf.global_variables_initializer()

with tf.Session(graph=graph) as sess:
    sess.run(init)
    for i in range(100):
        x1, x2, y1, y2 = np.random.randint(0, 100, 4)
        x1, x2 = sorted((x1, x2))
        y1, y2 = sorted((y1, y2))

        c_mask = np.zeros((100, 100))
        c_mask[x1:x2, y1:y2] = 255

        new_pic = sess.run([res], feed_dict={mask:c_mask})[0]

fig = plt.figure()
ax = fig.add_subplot(111)

ax.imshow(new_pic.astype('uint8'), cmap='gray')
fig.show()

基本上,它(至少应该是)在黑场上绘制100个随机白色矩形。但我得到的是:

WTF

我不明白。看起来每次迭代res只会重置为再次出现黑色平板(图片上的矩形是最后绘制的,由坐标判断)。我不是以某种方式保存它,或者我做错了什么?

1 个答案:

答案 0 :(得分:1)

更改这两行:

res = tf.add(res, mask)
##
new_pic = sess.run([res], feed_dict={mask:c_mask})[0]

update_op = res.assign(tf.add(res, mask))
##
new_pic = sess.run([res, update_op], feed_dict={mask:c_mask})[0]

我应该用Variable.assign()修改张量。

虽然您打算更新张量“res”,但第一行不更新它,而只创建另一个张量,名称“res”被分配给新张量。因此原始资源永远不会更新。