我正在尝试使用TensorFlow中的scatter_nd函数来重新排序Matrix行内的元素。例如,假设我有代码:
indices = tf.constant([[1],[0]])
updates = tf.constant([ [5, 6, 7, 8],
[1, 2, 3, 4] ])
shape = tf.constant([2, 4])
scatter1 = tf.scatter_nd(indices, updates, shape)
$ print(scatter1) = [[1,2,3,4]
[5,6,7,8]]
这会重新排序updates
矩阵的行。
我不想仅对行进行重新排序,而是还要对每行中的各个元素进行重新排序。如果我只有一个向量(等级为1的Tensor),则此示例有效:
indices = tf.constant([[1],[0],[2],[3]])
updates = tf.constant([5, 6, 7, 8])
shape = tf.constant([4])
scatter2 = tf.scatter_nd(indices, updates, shape)
$ print(scatter2) = [6,5,7,8]
我真正关心的是能够交换scatter1
中每行中的元素,就像我在scatter2
中所做的那样,但是对scatter1
的每一行都要这样做。我尝试了indices
的各种组合,但不断收到scatter_nd
函数引发的大小不一致的错误。
答案 0 :(得分:1)
以下使用scatter_nd
交换每行每行的元素indices = tf.constant([[[0, 1], [0, 0], [0, 2], [0, 3]],
[[1, 1], [1, 0], [1, 2], [1, 3]]])
updates = tf.constant([ [5, 6, 7, 8],
[1, 2, 3, 4] ])
shape = tf.constant([2, 4])
scatter1 = tf.scatter_nd(indices, updates, shape)
with tf.Session() as sess:
print(sess.run(scatter1))
输出:
[[6 5 7 8]
[2 1 3 4]]
indices
中坐标的位置定义updates
中取值的位置,实际坐标定义值放在scatter1
中的位置。
这个答案迟了几个月,但希望仍然有用。
答案 1 :(得分:0)
假设您要在第二维中交换元素,或者保持第一个维度顺序。
import tensorflow as tf
sess = tf.InteractiveSession()
def prepare_fd(fd_indices, sd_dims):
fd_indices = tf.expand_dims(fd_indices, 1)
fd_indices = tf.tile(fd_indices, [1, sd_dims])
return fd_indices
# define the updates
updates = tf.constant([[11, 12, 13, 14],
[21, 22, 23, 24],
[31, 32, 33, 34]])
sd_dims = tf.shape(updates)[1]
sd_indices = tf.constant([[1, 0, 2, 3], [0, 2, 1, 3], [0, 1, 3, 2]])
fd_indices_range = tf.range(0, limit=tf.shape(updates)[0])
fd_indices_custom = tf.constant([2, 0, 1])
# define the indices
indices1 = tf.stack((prepare_fd(fd_indices_range, sd_dims), sd_indices), axis=2)
indices2 = tf.stack((prepare_fd(fd_indices_custom, sd_dims), sd_indices), axis=2)
# define the shape
shape = tf.shape(updates)
scatter1 = tf.scatter_nd(indices1, updates, shape)
scatter2 = tf.scatter_nd(indices2, updates, shape)
print(scatter1.eval())
# array([[12, 11, 13, 14],
# [21, 23, 22, 24],
# [31, 32, 34, 33]], dtype=int32)
print(scatter2.eval())
# array([[21, 23, 22, 24],
# [31, 32, 34, 33],
# [12, 11, 13, 14]], dtype=int32)
愿这个例子有所帮助。