如何最好地在python中找到一组数组(二维数组)中给定数组的出现次数(使用numpy)? 这是(简化)我需要在python代码中表达的内容:
patterns = numpy.array([[1, -1, 1, -1],
[1, 1, -1, 1],
[1, -1, 1, -1],
...])
findInPatterns = numpy.array([1, -1, 1, -1])
numberOfOccurrences = findNumberOfOccurrences(needle=findInPatterns, haystack=patterns)
print(numberOfOccurrences) # should print e.g. 2
实际上,我需要找出在集合中找到每个数组的频率。但是上面代码中描述的功能对我来说已经很有帮助了。
现在,我知道我可以使用循环来做到这一点但是想知道是否有更有效的方法来做到这一点?谷歌搜索只让我numpy.bincount,它完全符合我的需要,但不适用于二维数组,只适用于整数。
答案 0 :(得分:4)
使用1
和-1
s的数组,使用np.dot
表现不会有任何好处:如果(且仅当)所有项目都匹配,则点积将添加最多为行中的项目数。所以你可以做到
>>> haystack = np.array([[1, -1, 1, -1],
... [1, 1, -1, 1],
... [1, -1, 1, -1]])
>>> needle = np.array([1, -1, 1, -1])
>>> haystack.dot(needle)
array([ 4, -2, 4])
>>> np.sum(haystack.dot(needle) == len(needle))
2
这是一种基于卷积的图像匹配的玩具特殊情况,您可以轻松地重写它以查找比完整行短的模式,甚至可以使用FFT加速它。
答案 1 :(得分:3)
import numpy
A = numpy.array([[1, -1, 1, -1],
[1, 1, -1, 1],
[1, -1, 1, -1]])
b = numpy.array([1, -1, 1, -1])
print ((A == b).sum(axis=1) == b.size).sum()
这将进行行匹配,我们选择并计算所有值与我们要查找的模式匹配的行。这要求b
具有与A[0]
相同的形状。
答案 2 :(得分:2)
有点像@Hooked的回答,但稍微冗长一点。
np.sum(np.equal(A, b).all(axis=1))
答案 3 :(得分:1)
怎么样:
>>> from collections import Counter
>>> c
[[1, -1, 1, -1], [1, -1, 1, 1], [2, 3, 4, 5], [1, -1, 1, -1]]
>>> Counter(list(tuple(i) for i in c))[tuple(c[0])]
2
答案 4 :(得分:0)
这个怎么样?它不使用numpy,但它很简单,适用于任何大小/形状的矩阵。思路非常简单:不是比较数组,而是比较元组(可以清除,因此很容易在本地进行比较)。
patterns = [[1, -1, 1, -1],
[1, 2, 3, 4],
[1, -1, 1, -1],
[1],
[],
[1, 1, 1, -1],
[1, -1, 1, -1]]
key = [1, -1, 1, -1]
def find_number_of_occurrences(needle, haystack):
needle = tuple(needle)
return len([straw for straw in haystack if tuple(straw) == needle])
print find_number_of_occurrences(key, patterns) # Prints 3
这只是通过haystack
并从匹配的元素(needles
,如果你愿意)中构建一个列表解析,并返回该列表的长度。我不确定它在效率方面与numpy功能相比如何,但它在代码中肯定是清晰易懂的。
答案 5 :(得分:0)
>>> import numpy
>>> haystack = numpy.array([[1, -1, 1, -1], [1, 1, -1, 1], [1, -1, 1, -1]])
>>> needle = numpy.array([1, -1, 1, -1])
>>> sum([numpy.equal(hay, needle).all() for hay in haystack])
2
使用numpy.equal()
进行比较会根据输入的比较返回True
或False
元素的数组。 all()
检查数组中所有元素的真实性,然后检查布尔元素列表中的sum()
。