如何在tensorflow中的3d张量中交换行

时间:2019-09-29 22:49:21

标签: python numpy tensorflow matrix

我已经待了两个小时了,这是一个简单的问题。我正在使用GA构建权重优化器并执行突变,因此我需要能够将一个NN中的一行权重与另一个NN交换

我对shape [population,total_input,total_output]的每一层都有一个3d张量。我在3d张量中选取一行,然后两个人必须交换完全相同的行的值。例如,行[nn1,row_to_swap]需要与行[nn2,row_to_swap]交换。

一个具有张量3的张量,输入节点3和输出节点2的示例具有此形状[3,3,2],在这里我想交换[0,0]和[1,0]:

      [[[ -0.08140966 -0.04416275 ],
        [ 0.08669635, -0.1681123 ],
        [ 0.06804892,  0.05393898]],

       [[ 0.11369397, -0.0822193 ],
        [-0.08230941,  0.16685687],
        [-0.08133464, -0.02710806]],

       [[ 0.08381592, -0.07583494],
        [-0.08355351,  0.07891247],
        [ 0.0392112 , -0.07686558]]]

应该看起来像这样。

  [[[ -0.08140966 -0.04416275 ],
    [ 0.08669635, -0.1681123 ],
    [ 0.06804892,  0.05393898]],

   [[ 0.11369397, -0.0822193 ],
    [-0.08230941,  0.16685687],
    [-0.08133464, -0.02710806]],

   [[ 0.08381592, -0.07583494],
    [-0.08355351,  0.07891247],
    [ 0.0392112 , -0.07686558]]]

只要抬起头,我就不知道张量的确切形状是什么,因为它们将使用shape变量创建。有时,这些方法可能必须进行多次交换。假设[1,0]和[1,2]必须与[0,0]和[0,2]交换,因此,如果有一种方法可以一次性进行多次交换,而又不会造成循环的话。

例如:

      [[[ -0.08140966 -0.04416275 ],
        [ 0.08669635, -0.1681123 ],
        [ 0.06804892,  0.05393898]],

       [[ 0.11369397, -0.0822193 ],
        [-0.08230941,  0.16685687],
        [-0.08133464, -0.02710806]],

       [[ 0.08381592, -0.07583494],
        [-0.08355351,  0.07891247],
        [ 0.0392112 , -0.07686558]]]

应该看起来像这样。

      [[[ 0.11369397, -0.0822193 ],
        [-0.08230941,  0.16685687]
        [-0.08133464, -0.02710806]],

       [[ -0.08140966 -0.04416275 ],
        [-0.08230941,  0.16685687],
        [ 0.06804892,  0.05393898]],

       [[ 0.08381592, -0.07583494],
        [-0.08355351,  0.07891247],
        [ 0.0392112 , -0.07686558]]]

Numpy似乎有一个简单的解决方案,他们使用

npArray[[0,0]] = npArray[[1,0]]

当然,TensorFlow有点复杂。

1 个答案:

答案 0 :(得分:0)

这可以使用tf.scatter_nd_update实现。请找到示例片段:

ref = tf.Variable([[[1, 2,3],[3, 4,5],[5, 6,7], [7, 8,9]]])
print(ref)
indices = tf.constant([[0,1,1], [0,1,0], [0,0,0] ,[0,0,1]])
updates = tf.constant([9, 10, 11, 12])
update = tf.scatter_nd_update(ref, indices, updates)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(ref))
    print(sess.run(update))