在numpy数组上过滤具有特定条件的tensorflow数组

时间:2019-06-10 19:38:08

标签: python arrays numpy tensorflow slice

我有一个Tensorflow数组名称tf-array和一个numpy数组名称np_array。我想在tf_array中找到关于np-array的特定行。

    tf-array = tf.constant(
                [[9.968594,  8.655439,  0.,        0.       ],
                 [0.,        8.3356,    0.,        8.8974   ],
                 [0.,        0.,        6.103182,  7.330564 ],
                 [6.609862,  0.,        3.0614321, 0.       ],
                 [9.497023,  0.,        3.8914037, 0.       ],
                 [0.,        8.457685,  8.602337,  0.       ],
                 [0.,        0.,        5.826657,  8.283971 ]])

我也有一个np数组:

np_array = np.matrix(
 [[2, 5, 1],
  [1, 6, 4],
  [0, 0, 0],
  [2, 3, 6],
  [4, 2, 4]]

现在,我要将元素tf-array n的组合(它们的索引)的值保持在(here n is 2)中,保留在np-array中。什么意思?

例如,在tf-array的第一列中,具有值的索引为:(0,3,4)np-array中是否有包含以下两个索引的任意组合的行:(0,3), (0,4) or (3,4)。实际上,没有这样的行。因此,该列中的所有元素都变为zero

tf-array中第二列的索引为(0,1) (0,5) (1,5)。如您所见,记录(1,5)在第一行的np-array中可用。这就是为什么我们将这些保留在tf-array中的原因。

所以最终结果应该是这样的:

[[0.        0.        0.        0.       ]
 [0.        8.3356    0.        8.8974   ]
 [0.        0.        6.103182  7.330564 ]
 [0.        0.        3.0614321 0.       ]
 [0.        0.        3.8914037 0.       ]
 [0.        8.457685  8.602337  0.       ]
 [0.        0.        5.826657  8.283971 ]]

由于大量数据,我正在寻找一种非常有效的方法。

更新1

我可以使用下面的代码来实现这一点,该代码为True赋予false有值和零掩码:

[[ True  True False False]
 [False  True False  True]
 [False False  True  True]
 [ True False  True False]
 [ True False  True False]
 [False  True  True False]
 [False False  True  True]]

with tf.Session() as sess:  
 where = tf.not_equal(tf-array, 0.0)
 print(sess.run(where))

但是我怎么能将Theese矩阵与np_array进行比较?

提前谢谢!

2 个答案:

答案 0 :(得分:1)

您可以尝试的一种有效方法是使每行的位标志的值(0,3,4)等于1 << 0 | 1 << 3 | 1 << 4。您将拥有带有标志的值数组。请尝试<<和|运算符在numpy中工作。 使另一个数组相同,我想tf-数组只是包装的numpys。 在具有2个标志数组之后,对这些标志进行按位“与”运算。如果条件对行为真,则结果将至少具有两个非零位。谷歌也可以做到高效,谷歌。

此工具无法与float一起使用-您需要将其转换为非常小的整数。

import numpy as np



arr_one =  np.array(
 [[2, 5, 1],
  [1, 6, 4],
  [0, 0, 0],
  [2, 3, 6],
  [4, 2, 4]])

arr_two =  np.array(
 [[2, 0, 7],
  [1, 3, 4],
  [5, 5, 6],
  [1, 3, 6],
  [4, 2, 4]])




print('1 << arr_one.T[0] ' , 1 << arr_one.T[0] )


arr_one_flags = 1 << arr_one.T[0] | 1 << arr_one.T[1] | 1 << arr_one.T[2]

print('arr_one_flags ', arr_one_flags)

arr_two_flags = 1 << arr_two.T[0] | 1 << arr_two.T[1] | 1 << arr_two.T[2]

arr_and = arr_one_flags & arr_two_flags

print('arr_and ', arr_and)



def get_bit_count(value):
    n = 0
    while value:
        n += 1
        value &= value-1
    return n

arr_matches = np.array([get_bit_count(x) for x in arr_and])


print('arr_matches ', arr_matches )


arr_two_filtered = arr_two[arr_matches > 1]

print('arr_two_filtered ', arr_two_filtered )

答案 1 :(得分:1)

这是https://stackoverflow.com/a/56510832/7207392中的解决方案,并进行了必要的修改。为了简单起见,我对所有数据使用np.array。我不是tensortflow专家,所以如果翻译不完全是直截了当的,您将不得不问别人该怎么做。

import numpy as np

def f(a1, a2, n):
    N,M = a1.shape
    a1p = np.concatenate([a1,np.zeros((1,a1.shape[1]),a1.dtype)], axis=0)
    a2 = np.sort(a2, axis=1)
    a2[:,1:][a2[:,1:]==a2[:,:-1]] = N
    y,x = np.where(np.count_nonzero(a1p[a2], axis=1) >= n)
    out = np.zeros_like(a1p)
    out[a2[y],x[:,None]] = a1p[a2[y],x[:,None]]
    return out[:-1]

a1 = np.array(
    [[9.968594,  8.655439,  0.,        0.       ],
     [0.,        8.3356,    0.,        8.8974   ],
     [0.,        0.,        6.103182,  7.330564 ],
     [6.609862,  0.,        3.0614321, 0.       ],
     [9.497023,  0.,        3.8914037, 0.       ],
     [0.,        8.457685,  8.602337,  0.       ],
     [0.,        0.,        5.826657,  8.283971 ]])

a2 = np.array(
 [[2, 5, 1],
  [1, 6, 4],
  [0, 0, 0],
  [2, 3, 6],
  [4, 2, 4]])

print(f(a1,a2,2))

输出:

[[0.        0.        0.        0.       ]
 [0.        8.3356    0.        8.8974   ]
 [0.        0.        6.103182  7.330564 ]
 [0.        0.        3.0614321 0.       ]
 [0.        0.        3.8914037 0.       ]
 [0.        8.457685  8.602337  0.       ]
 [0.        0.        5.826657  8.283971 ]]