我' m跟随此Manipulating matrix elements in tensorflow。使用tf.scatter_update。但我的问题是: 如果我的tf.Variable是2D,会发生什么?让我们说:
a = tf.Variable(initial_value=[[0, 0, 0, 0],[0, 0, 0, 0]])
我如何更新例如每一行的第一个元素并将其赋值为1?
我试过像
这样的东西for line in range(2):
sess.run(tf.scatter_update(a[line],[0],[1]))
但它失败了(我期待那样)并且给了我错误:
TypeError:输入' ref' ' ScatterUpdate'操作需要l值输入
我该如何解决这类问题?
`
答案 0 :(得分:12)
在tensorflow中,您无法更新Tensor,但可以更新变量。
/static/
运算符只能更新变量的第一维。
您必须始终将参考张量传递给散点图更新(scatter_update
而不是a
)。
这是更新变量的第一个元素的方法:
a[line]
输出:
import tensorflow as tf
g = tf.Graph()
with g.as_default():
a = tf.Variable(initial_value=[[0, 0, 0, 0],[0, 0, 0, 0]])
b = tf.scatter_update(a, [0, 1], [[1, 0, 0, 0], [1, 0, 0, 0]])
with tf.Session(graph=g) as sess:
sess.run(tf.initialize_all_variables())
print sess.run(a)
print sess.run(b)
但是不得不再次改变整个张量,只需分配一个全新的张量可能会更快。
答案 1 :(得分:1)
我找到了here的内容 我做了一个变量名U = [[1,2,3],[4,5,6]],并希望像U [:,1] = [2,3]那样更新它,所以我做了U [:, 1] .assign(cast_into_tensor [2,3])
这里有一个简单的代码
x = tf.Variable([[1,2,3],[4,5,6]])
print K.eval(x)
y = [0, 0]
with tf.control_dependencies([x[:,1].assign(y)]):
x = tf.identity(x)
print K.eval(x)