我有来自'n'个不同人的标签,这些人对'm'个项目(0或1)进行了评分,因此是一个m x n的数组。例如,有3个人对4个项目进行评分:
arr = np.asarray([[1,1,1], [1,1,0], [0,0,0], [0, 1, 0]])
print(arr)
>>>
[[1 1 1]
[1 1 0]
[0 0 0]
[0 1 0]]
我想看看每个人都“同意”哪些项目,即行中的所有值都相同。在此示例中,答案为[True,False,True,False]。我使用它来工作:
np.logical_or(arr.sum(axis=1) == n, arr.sum(axis=1) == 0)
有点骇客。有什么更好的方法?
答案 0 :(得分:2)
一种选择是沿行计算diff
,然后检查所有diff
是否等于0;这将确保一行中的所有元素都相同(并且可以不同于0和1):
(np.diff(arr, axis=1) == 0).all(axis=1)
# array([ True, False, True, False], dtype=bool)
或者如果您只有0和1,那么:
(arr == 1).all(1) | (arr == 0).all(1)
# array([ True, False, True, False], dtype=bool)
arr.all(1) | ~arr.any(1)
# array([ True, False, True, False], dtype=bool)
答案 1 :(得分:1)
我认为len(set(.))
基本上就是您正在寻找的is_uniform
函数:
[len(set(x)) == 1 for x in arr]
请注意,此解决方案非常笼统,不需要:
答案 2 :(得分:0)
或者使用list comprehension
制作长度等于i
的第一个元素的元素的元素(因此,基本上看一下它们在i
中是否都相同),如果它们与条件不匹配,则改为False
:
print([len(i)==i.tolist().count(i[0]) for i in arr])
输出:
[True, False, True, False]