从具有二维数组的3维数组中选择

时间:2019-05-28 10:27:29

标签: python arrays numpy tensorflow

我有两个数组:

  • a:一个3维源数组( N x M x 2
  • b:包含0和1s的二维索引数组( N x M )。

我想使用b中的索引在其第三维中选择a的相应元素。所得数组的尺寸应为 N x M 。这是代码示例:

import numpy as np

a = np.array( # dims: 3x3x2
    [[[ 0,  1],
     [ 2,  3],
     [ 4,  5]],
    [[ 6,  7],
     [ 8,  9],
     [10, 11]],
    [[12, 13],
     [14, 15],
     [16, 17]]]
)
b = np.array( # dims: 3x3
    [[1, 1, 1],
    [1, 1, 1],
    [1, 1, 1]]
)

# select the elements in a according to b
# to achieve this result:
desired = np.array(
  [[ 1,  3,  5],
   [ 7,  9, 11],
   [13, 15, 17]]
)

起初,我认为这必须有一个简单的解决方案,但我根本找不到。因为我想将其移植到tensorflow,所以如果有人知道numpy类型的解决方案,我将不胜感激。

编辑a的第三维可能包含两个以上的元素。因此,b可能还包含不同于0和1的索引-它不是布尔掩码。

4 个答案:

答案 0 :(得分:2)

@jdehesa提示,我们可以使用np.ogrid来获取前两个轴的索引:

ax0, ax1 = np.ogrid[:b.shape[0], :b.shape[1]]

然后我们可以使用b直接沿最后一个轴索引。请注意,ax0ax1将以b的形式广播:

desired = a[ax0, ax1 ,b] 

print(desired)
array([[ 1,  3,  5],
       [ 7,  9, 11],
       [13, 15, 17]])

答案 1 :(得分:2)

我们可以使用np.where

np.where(b, a[:, :, 1], a[:, :, 0])

输出:

array([[ 1,  3,  5],
       [ 7,  9, 11],
       [13, 15, 17]])

答案 2 :(得分:2)

我为张量流添加了一些解决方案。

import tensorflow as tf

a = tf.constant([[[ 0,  1],[ 2,  3],[ 4,  5]],
                 [[ 6,  7],[ 8,  9],[10, 11]],
                 [[12, 13],[14, 15],[16, 17]]],dtype=tf.float32)
b = tf.constant([[1, 1, 1],[1, 1, 1],[1, 1, 1]],dtype=tf.int32)

# 1. use tf.gather_nd
colum,row = tf.meshgrid(tf.range(a.shape[0]),tf.range(a.shape[1]))
idx = tf.stack([row, colum, b], axis=-1) # Thanks for @jdehesa's suggestion
result1 = tf.gather_nd(a,idx)

# 2. use tf.reduce_sum
mask = tf.one_hot(b,depth=a.shape[-1],dtype=tf.float32)
result2 = tf.reduce_sum(a*mask,axis=-1)

# 3. use tf.boolean_mask
mask = tf.one_hot(b,depth=a.shape[-1],dtype=tf.float32)
result3 = tf.reshape(tf.boolean_mask(a,mask),b.shape)

with tf.Session() as sess:
    print('method 1: \n',sess.run(result1))
    print('method 2: \n',sess.run(result2))
    print('method 3: \n',sess.run(result3))

method 1: 
 [[ 1.  3.  5.]
 [ 7.  9. 11.]
 [13. 15. 17.]]
method 2: 
 [[ 1.  3.  5.]
 [ 7.  9. 11.]
 [13. 15. 17.]]
method 3: 
 [[ 1.  3.  5.]
 [ 7.  9. 11.]
 [13. 15. 17.]]

答案 3 :(得分:1)

您可以使用np.take_along_axis

   "content":" <p><img alt=""   
              src="data:image\/png;base64,
              iVBORw0KGgoAAAANSUhEUgAAA7YAAAVx