我尝试使用scatter_update更新tensorflow对角线权重矩阵,但到目前为止没有任何运气。它要么提示形状不匹配,要么仅沿第一行更新。这是非常奇怪的API行为。有人可以帮我吗?谢谢
Example:
dia_mx = tf.Variable(initial_value=np.array([[1.,0.,0.],
[0.,1.,0.],
[0.,0.,1.]]))
new_diagonal_values = np.array([2., 3., 4.])
tf.scatter_update(dia_mx, [[0,0],[1,1],[2,2]], new_diagonal_values)
Get error:
InvalidArgumentError: shape of indices ([3,2]) is not compatible with the shape of updates ([3]) [Op:ResourceScatterUpdate]
Expect new diagonal matrix:
dia_mx = [[2.,0.,0.],
[0.,3.,0.],
[0.,0.,4.]]
答案 0 :(得分:0)
要使用张量更新特定索引,请使用 tf.scatter_nd_update()
:
import tensorflow as tf
import numpy as np
dia_mx = tf.Variable(initial_value=np.array([[1.,0.,0.],
[0.,1.,0.],
[0.,0.,1.]]))
updates = [tf.constant(2.), tf.constant(3.), tf.constant(4.)]
indices = tf.constant([[0, 0], [1, 1], [2, 2]])
update_tensor = tf.scatter_nd_update(dia_mx, indices, updates)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(update_tensor.eval())
# [[2. 0. 0.]
# [0. 3. 0.]
# [0. 0. 4.]]
tf.scatter_update()
沿张量的第一个维度应用更新。在这种特殊情况下,这意味着立即将更新应用于矩阵的整个行:
dia_mx = tf.Variable(initial_value=np.array([[1.,0.,0.],
[0.,1.,0.],
[0.,0.,1.]]), dtype=tf.float32)
updates = tf.constant([[2., 0., 0.], [0., 3., 0.], [0., 0., 4.]], dtype=tf.float32)
indices = tf.constant([0, 1, 2])
update_tensor = tf.scatter_update(dia_mx, indices, updates)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(update_tensor.eval())
# [[2. 0. 0.]
# [0. 3. 0.]
# [0. 0. 4.]]