为什么在此贝叶斯分类器中使用此布尔值? (Python问题?)

时间:2019-01-24 00:26:06

标签: python boolean generator bayesian generative-adversarial-network

我正在研究GAN(我是python的初学者),并且在以前的练习中发现了我不理解的代码的这一部分。具体来说,我不明白为什么使用第9行的布尔值( Xk = X [Y == k] ),原因如下:

class BayesClassifier:
  def fit(self, X, Y):
    # assume classes are numbered 0...K-1
    self.K = len(set(Y))

    self.gaussians = []
    self.p_y = np.zeros(self.K)
    for k in range(self.K):
      Xk = X[Y == k]
      self.p_y[k] = len(Xk)
      mean = Xk.mean(axis=0)
      cov = np.cov(Xk.T)
      g = {'m': mean, 'c': cov}
      self.gaussians.append(g)
    # normalize p(y)
    self.p_y /= self.p_y.sum()
  1. 该布尔值根据Y的真实性返回0或1 == k,因此,Xk始终是X列表的第一个或第二个值。是的,没有找到有用的工具。
  2. 在第10行中,len(Xk)始终为1,为什么它使用该参数而不是单个参数?
  3. 下一行的均值和协方差每次只用一个值来计算。

我觉得我不太了解某些基本知识。

2 个答案:

答案 0 :(得分:2)

您应考虑到X, Y, k是NumPy数组,而不是标量,并且某些运算符对其超载。特别是==和基于布尔的索引。 ==将是逐元素比较,而不是整个数组比较。

查看其工作原理:

In [9]: Y = np.array([0,1,2])                                                                                        
In [10]: k = np.array([0,1,3])                                                                                       
In [11]: Y==k                                                                                                        

Out[11]: array([ True,  True, False])

因此,==的结果是一个布尔数组。

In [12]: X=np.array([0,2,4])                                                                                         
In [13]: X[Y==k]                                                                                                     

Out[13]: array([0, 2])

当条件为X时,结果是一个数组,其中包含从True中选择的元素

因此len(Xk)将是Xk之间匹配元素的数量。

答案 1 :(得分:0)

谢谢Artem

你是对的。我通过另一个渠道找到了另一个答案,这里是:

  

这是一个Numpy数组-这是NumPy数组的一个特殊功能,称为   布尔索引,可让您仅过滤掉数组中的值   过滤器返回True的地方:

     

https://docs.scipy.org/doc/numpy-1.13.0/user/basics.indexing.html?fbclid=IwAR3sGlgSwhv3i7IETsIxp4ROu9oZvNaaaBxZS01DrM5ShjWWRz22ShP2rIg#boolean-or-mask-index-arrays

     

将numpy导入为np

     

a = np.array([1、2、3、4、5])过滤器= a> 3

     

打印(过滤器)

     
    
      
        

[False,False,False,True,True]

      
    
  
     

print(a [过滤器])

     
    
      
        

[4,5]