在数组中的列之间进行numpy查找协议

时间:2018-12-14 02:12:31

标签: python arrays numpy

我有来自'n'个不同人的标签,这些人对'm'个项目(0或1)进行了评分,因此是一个m x n的数组。例如,有3个人对4个项目进行评分:

arr = np.asarray([[1,1,1], [1,1,0], [0,0,0], [0, 1, 0]])
print(arr)
>>>
[[1 1 1]
 [1 1 0]
 [0 0 0]
 [0 1 0]]

我想看看每个人都“同意”哪些项目,即行中的所有值都相同。在此示例中,答案为[True,False,True,False]。我使用它来工作:

np.logical_or(arr.sum(axis=1) == n, arr.sum(axis=1) == 0)

有点骇客。有什么更好的方法?

3 个答案:

答案 0 :(得分:2)

一种选择是沿行计算diff,然后检查所有diff是否等于0;这将确保一行中的所有元素都相同(并且可以不同于0和1):

(np.diff(arr, axis=1) == 0).all(axis=1)
# array([ True, False,  True, False], dtype=bool)

或者如果您只有0和1,那么:

(arr == 1).all(1) | (arr == 0).all(1)
# array([ True, False,  True, False], dtype=bool)

arr.all(1) | ~arr.any(1)
# array([ True, False,  True, False], dtype=bool)

答案 1 :(得分:1)

我认为len(set(.))基本上就是您正在寻找的is_uniform函数:

[len(set(x)) == 1 for x in arr]

请注意,此解决方案非常笼统,不需要:

  1. 每个项目的投票人数相同
  2. 值为数字或任何特定类型的值
  3. 核心python之上的附加包

答案 2 :(得分:0)

或者使用list comprehension制作长度等于i的第一个元素的元素的元素(因此,基本上看一下它们在i中是否都相同),如果它们与条件不匹配,则改为False

print([len(i)==i.tolist().count(i[0]) for i in arr])

输出:

[True, False, True, False]