如何使用numpy或Theano将矩阵的整数值作为另一个矩阵的索引?

时间:2016-12-18 13:02:46

标签: python numpy theano

我有以下4个相同形状的矩阵:(1)包含整数值的矩阵I,(2)包含整数值的矩阵J,(3)矩阵{{1包含浮点值和(4)包含浮点值的矩阵D

我想用这4个矩阵构建一个"输出"矩阵以下列方式:

  1. 要查找输出矩阵的元素Vi的值,请找到矩阵j中等于I的所有单元格(元素)和矩阵i的所有单元格(元素)等于J
  2. 仅使用满足这两个条件的单元格(请记住矩阵jI具有相同的形状)。
  3. 在"选择"之间搜索单元格具有最小值J的单元格。
  4. 获取找到的单元格(D值最小)并检查它在矩阵D中的值。
  5. 通过这种方式,我们可以找到输出矩阵的Vi元素的值。我为所有jis执行此操作。

    我想使用numpy或Theano来解决这个问题。

    当然我可以循环遍历所有i_s和j_s,但我认为(希望)应该有一种更有效的方式。

    ADDED

    根据要求,我举了一个例子:

    这是矩阵I:

    js

    这是矩阵J:

     0   1   2
     1   1   0
     0   0   2
    

    这是矩阵D:

     1   1   1
     1   2   1
     0   1   0
    

    最后我们得到矩阵V:

     1.2   3.4   2.2
     2.2   4.3   2.3
     7.1   6.1   2.7
    

    如您所见,所有4个矩阵具有相同的形状(3 x 4),但它们可以具有其他形状(例如2 x 5)。主要的是所有4个矩阵的形状都是相同的。

    正如我们所看到的,矩阵 1.1 8.1 9.1 3.1 7.1 2.1 0.1 5.1 3.1 的值从0到2,因此输出矩阵应该有3行。同样,我们可以得出结论,输出矩阵应该有3列(因为矩阵I的值也是0到2)。

    让我们首先找到输出矩阵的元素(0,1)。在J矩阵中,以下单元格(由x标记)包含0。

    I

    在矩阵 x . . . . x x x . 中,以下元素包含1:

    J

    这两组细胞的交集是:

     x   x   x
     x   .   x
     .   x   .
    

    相应的距离为:

     x   .   .
     .   .   x
     .   x   .
    

    因此,最小距离位于左上角。因此,我们从矩阵 1.2 . . . . 2.3 . 6.1 . 的左上角获取值(此值为1.1)。

    这就是我们如何找到输出矩阵的(0,1)元素的值。我们对所有可能的索引组合(总共有3 x 3 = 9)组合执行相同的过程。对于某些组合,我们找不到任何值,在这种情况下,我们将值设置为V

1 个答案:

答案 0 :(得分:2)

这是使用broadcasting -

的矢量化方法
# Get mask of matching elements against the iterators
m,n = I.shape
Imask = I == np.arange(m)[:,None,None,None]
Jmask = J == np.arange(n)[:,None,None]

# Get the mask of intersecting ones
mask = Imask & Jmask

# Get D intersection masked array
Dvals = np.where(mask,D,np.inf)

# Get argmin along merged last two axes. Index into flattened V for final o/p
out = V.ravel()[Dvals.reshape(m,n,-1).argmin(-1)]

示例输入,输出 -

In [136]: I = np.array([[0,1,2],[1,1,0],[0,0,2]])
     ...: J = np.array([[1,1,1],[1,2,1],[0,1,0]])
     ...: D = np.array([[1.2, 3.4, 2.2],[2.2, 4.3, 2.3],[7.1, 6.1, 2.7]])
     ...: V = np.array([[1.1 , 8.1, 9.1],[3.1, 7.1, 2.1],[0.1, 5.1, 3.1]])
     ...: 

In [144]: out
Out[144]: 
array([[ 0.1,  1.1,  1.1], # To verify : v[0,1] = 1.1
       [ 1.1,  3.1,  7.1],
       [ 3.1,  9.1,  1.1]])