我正在研究GAN(我是python的初学者),并且在以前的练习中发现了我不理解的代码的这一部分。具体来说,我不明白为什么使用第9行的布尔值( Xk = X [Y == k] ),原因如下:
class BayesClassifier:
def fit(self, X, Y):
# assume classes are numbered 0...K-1
self.K = len(set(Y))
self.gaussians = []
self.p_y = np.zeros(self.K)
for k in range(self.K):
Xk = X[Y == k]
self.p_y[k] = len(Xk)
mean = Xk.mean(axis=0)
cov = np.cov(Xk.T)
g = {'m': mean, 'c': cov}
self.gaussians.append(g)
# normalize p(y)
self.p_y /= self.p_y.sum()
我觉得我不太了解某些基本知识。
答案 0 :(得分:2)
您应考虑到X, Y, k
是NumPy数组,而不是标量,并且某些运算符对其超载。特别是==
和基于布尔的索引。 ==
将是逐元素比较,而不是整个数组比较。
查看其工作原理:
In [9]: Y = np.array([0,1,2])
In [10]: k = np.array([0,1,3])
In [11]: Y==k
Out[11]: array([ True, True, False])
因此,==
的结果是一个布尔数组。
In [12]: X=np.array([0,2,4])
In [13]: X[Y==k]
Out[13]: array([0, 2])
当条件为X
时,结果是一个数组,其中包含从True
中选择的元素
因此len(Xk)
将是X
和k
之间匹配元素的数量。
答案 1 :(得分:0)
谢谢Artem
你是对的。我通过另一个渠道找到了另一个答案,这里是:
这是一个Numpy数组-这是NumPy数组的一个特殊功能,称为 布尔索引,可让您仅过滤掉数组中的值 过滤器返回True的地方:
将numpy导入为np
a = np.array([1、2、3、4、5])过滤器= a> 3
打印(过滤器)
[False,False,False,True,True]
print(a [过滤器])
[4,5]