不理解书中的numpy代码

时间:2018-04-12 14:17:01

标签: python numpy

X = np.array([[0, 1, 0, 1],
              [1, 0, 1, 1],
              [0, 0, 0, 1],
              [1, 0, 1, 0]])
y = np.array([0, 1, 0, 1])
counts = {}
for label in np.unique(y):
     counts[label] = X[y == label].sum(axis=0)
print("Feature counts: ", counts)'

此代码用于检查类的特征不为零的次数,但我不理解语法counts[label] = X[y == label].sum(axis=0)。当我只运行print(y==label)时,会出现numpy数组[False True False True],我不明白这个索引和numpy数组中的项目总和。此外,我不明白为什么`y == label'已设置。 任何帮助表示赞赏。谢谢。

2 个答案:

答案 0 :(得分:0)

您所看到的是经典的分组操作;令人遗憾的是,numpy并没有以优雅的方式开箱即用。你似乎不太关心高层次的理解而不是低层次的理解;但如果后者本身并不是一个目标,那么有一些替代方案可以将这些担忧从你身上剔除,例如numpy_indexed(免责声明:我是其作者)

import numpy_indexed as npi
labels, counts = npi.group_by(y).sum(X)

这将做同样的事情,但是以矢量化,因此更具可扩展性的方式。

答案 1 :(得分:0)

我不认为代码正在执行非零元素检查。对于你的其他部分问题。 y == 1会给你一个布尔掩码

如果要检查非零,请调用唯一函数并删除零元素计数

np.unique(X, return_counts=True)
(array([0, 1]), array([8, 8], dtype=int64))# first array is element, second is count

布尔掩码示例:

print y
[0 1 0 1]

这会将y中的所有元素与1进行比较,并返回一个真/假列表

print y == 1
[False  True False  True]

基于上面的真/假列表,它从X

中选择第1行和第3行
print X[y == 1]
[[1 0 1 1]
 [1 0 1 0]]