如何在Python中使用MLE估计高斯分布参数

时间:2018-07-14 19:29:14

标签: python machine-learning statistics pattern-recognition

我有一组具有高斯分布的数据,这是一个直方图,显示了它们的实际外观:

Gaussian distribution

我必须使用贝叶斯分类器将这些数据分为两类,我正在使用sklearn进行分类,并且运行良好。但是,作为工作的一部分,我必须使用MLE估计数据的分布参数(σ,μ),并在我的分类器中使用它们。

那么,有没有可以使用最大似然法估算高斯分布参数的python库或伪代码,因此我可以在分类器中使用估算值?

我正在从Matlab寻找类似mle(data,'distribution',dist)的东西。

phat = mle(MPG,'distribution','burr')
phat =
34.6447    3.7898    3.5722

1 个答案:

答案 0 :(得分:0)

由于您的数据是多维的(D,在您的情况下具体为D = 15),因此您需要对数据的均值(D维)和协方差(D ^ 2维)建模。

您可以按照以下方式使用numpy轻松实现它

import numpy as np

def gaussian_mle(data):                                                                                                                                                                               
    mu = data.mean(axis=0)                                                                                                                                                                            
    var = (data-mu).T @ (data-mu) / data.shape[0] #  this is slightly suboptimal, but instructive

    return mu, var                                                                                                                                                                                    

要查看其效果,请在一些人工数据上运行它:

mean = [1.0, 3.14]                                                                                                                                                                                    
cov = [[2.0, 0.5], [0.5, 10]]                                                                                                                                                                         
data = np.random.multivariate_normal(mean, cov, 10000)                                                                                                                                                

print(gaussian_mle(data))

自定义格式后,这给了我们(我们随机抽样,结果可能会略有不同):

(
    array([1.00981014, 3.1217965 ]), #  sample mean
    array([[2.0266404 , 0.43036865], 
           [0.43036865, 9.87599803]]) #  sample covariance
)