在二维tf.Variable中使用tf.scatter_update

时间:2016-10-29 15:19:53

标签: python matrix tensorflow

我' m跟随此Manipulating matrix elements in tensorflow。使用tf.scatter_update。但我的问题是: 如果我的tf.Variable是2D,会发生什么?让我们说:

a = tf.Variable(initial_value=[[0, 0, 0, 0],[0, 0, 0, 0]])

我如何更新例如每一行的第一个元素并将其赋值为1?

我试过像

这样的东西
for line in range(2):
    sess.run(tf.scatter_update(a[line],[0],[1]))

但它失败了(我期待那样)并且给了我错误:

  

TypeError:输入' ref' ' ScatterUpdate'操作需要l值输入

我该如何解决这类问题?

`

2 个答案:

答案 0 :(得分:12)

在tensorflow中,您无法更新Tensor,但可以更新变量。

/static/运算符只能更新变量的第一维。 您必须始终将参考张量传递给散点图更新(scatter_update而不是a)。

这是更新变量的第一个元素的方法:

a[line]

输出:

import tensorflow as tf

g = tf.Graph()
with g.as_default():
    a = tf.Variable(initial_value=[[0, 0, 0, 0],[0, 0, 0, 0]])
    b = tf.scatter_update(a, [0, 1], [[1, 0, 0, 0], [1, 0, 0, 0]])

with tf.Session(graph=g) as sess:
   sess.run(tf.initialize_all_variables())
   print sess.run(a)
   print sess.run(b)

但是不得不再次改变整个张量,只需分配一个全新的张量可能会更快。

答案 1 :(得分:1)

我找到了here的内容 我做了一个变量名U = [[1,2,3],[4,5,6]],并希望像U [:,1] = [2,3]那样更新它,所以我做了U [:, 1] .assign(cast_into_tensor [2,3])

这里有一个简单的代码

x = tf.Variable([[1,2,3],[4,5,6]])
print K.eval(x)
y = [0, 0]
with tf.control_dependencies([x[:,1].assign(y)]):
    x = tf.identity(x)
print K.eval(x)