交换矩阵行和列中的元素 - TensorFlow scatter_nd

时间:2017-02-13 15:22:55

标签: python matrix tensorflow

我正在尝试使用TensorFlow中的scatter_nd函数来重新排序Matrix行内的元素。例如,假设我有代码:

indices = tf.constant([[1],[0]])
updates = tf.constant([ [5, 6, 7, 8],
                        [1, 2, 3, 4] ])
shape = tf.constant([2, 4])
scatter1 = tf.scatter_nd(indices, updates, shape)
$ print(scatter1) = [[1,2,3,4]
                     [5,6,7,8]]

这会重新排序updates矩阵的行。

我不想仅对行进行重新排序,而是还要对每行中的各个元素进行重新排序。如果我只有一个向量(等级为1的Tensor),则此示例有效:

indices = tf.constant([[1],[0],[2],[3]])
updates = tf.constant([5, 6, 7, 8])
shape = tf.constant([4])
scatter2 = tf.scatter_nd(indices, updates, shape)
$ print(scatter2) = [6,5,7,8]

我真正关心的是能够交换scatter1中每行中的元素,就像我在scatter2中所做的那样,但是对scatter1的每一行都要这样做。我尝试了indices的各种组合,但不断收到scatter_nd函数引发的大小不一致的错误。

2 个答案:

答案 0 :(得分:1)

以下使用scatter_nd

交换每行每行的元素
indices = tf.constant([[[0, 1], [0, 0], [0, 2], [0, 3]], 
                      [[1, 1], [1, 0], [1, 2], [1, 3]]])
updates = tf.constant([ [5, 6, 7, 8],
                        [1, 2, 3, 4] ])
shape = tf.constant([2, 4])
scatter1 = tf.scatter_nd(indices, updates, shape)
with tf.Session() as sess:
    print(sess.run(scatter1))

输出:
[[6 5 7 8] [2 1 3 4]]

indices中坐标的位置定义updates中取值的位置,实际坐标定义值放在scatter1中的位置。

这个答案迟了几个月,但希望仍然有用。

答案 1 :(得分:0)

假设您要在第二维中交换元素,或者保持第一个维度顺序。

import tensorflow as tf
sess = tf.InteractiveSession()


def prepare_fd(fd_indices, sd_dims):
    fd_indices = tf.expand_dims(fd_indices, 1)
    fd_indices = tf.tile(fd_indices, [1, sd_dims])
    return fd_indices

# define the updates
updates = tf.constant([[11, 12, 13, 14],
                       [21, 22, 23, 24],
                       [31, 32, 33, 34]])
sd_dims = tf.shape(updates)[1]

sd_indices = tf.constant([[1, 0, 2, 3], [0, 2, 1, 3], [0, 1, 3, 2]])
fd_indices_range = tf.range(0, limit=tf.shape(updates)[0])
fd_indices_custom = tf.constant([2, 0, 1])

# define the indices
indices1 = tf.stack((prepare_fd(fd_indices_range, sd_dims), sd_indices), axis=2)
indices2 = tf.stack((prepare_fd(fd_indices_custom, sd_dims), sd_indices), axis=2)

# define the shape
shape = tf.shape(updates)

scatter1 = tf.scatter_nd(indices1, updates, shape)
scatter2 = tf.scatter_nd(indices2, updates, shape)

print(scatter1.eval())

# array([[12, 11, 13, 14],
#        [21, 23, 22, 24],
#        [31, 32, 34, 33]], dtype=int32)

print(scatter2.eval())

# array([[21, 23, 22, 24],
#        [31, 32, 34, 33],
#        [12, 11, 13, 14]], dtype=int32)

愿这个例子有所帮助。