因此,我正在上一门数据科学课程,这周一直在研究高斯生成分类模型。我编写了以下函数以适合训练数据:
def fit_generative_model(x,y):
k = 10 # labels 0,1,...,k-1
d = (x.shape)[1] # number of features
mu = np.zeros((k,d))
sigma = np.zeros((k,d,d))
sigmac = np.zeros((k,d,d))
pi = np.zeros(k)
c = 100
# identity matrix for sigma
im = np.identity(d)
for lb in range(0,k):
indx = (y == lb)
mu[lb] = np.mean(x[indx,:], axis=0)
sigma[lb] = np.cov(x[indx,:], rowvar=0, bias=1)
#cim = im*c
#sigmac[lb]=np.matmul(sigma[lb],cim)
sigmac[lb]=sigma[lb]*c
pi[lb]=float(sum(indx))/float(len(y))
# Halt and return parameters
print(mu.shape,sigma.shape,pi.shape)
return mu, sigmac, pi
然后我写了:
# let's find out how many errors I have
kp = 10
tp = len(test_labels)
score = np.zeros((tp,kp))
for rw in range(0,tp):
for pred in range(0,kp):
score[rw,pred] = np.log(pi[pred]) + \
multivariate_normal.logpdf(test_data[rw,:], mean = mu[pred,:], cov = sigma[pred,:,:])
predictions = np.argmax(score[:,:], axis=0) + 1
errors = np.sum(predictions != test_labels)
print('Test errors using feature(s):')
for f in features:
print(featurenames[f],'errors:',str(errors)+'/'+str(tp))
这是我运行它时得到的追溯:
---------------------------------------------------------------------------
LinAlgError Traceback (most recent call last)
<ipython-input-15-36dfe1a7193c> in <module>
7 for pred in range(0,kp):
8 score[rw,pred] = np.log(pi[pred]) + \
----> 9 multivariate_normal.logpdf(test_data[rw,:], mean = mu[pred,:], cov = sigma[pred,:,:])
10
11 predictions = np.argmax(score[:,:], axis=0) + 1
c:\users\jbustos\appdata\local\programs\python\python37\lib\site-packages\scipy\stats\_multivariate.py in logpdf(self, x, mean, cov, allow_singular)
493 dim, mean, cov = self._process_parameters(None, mean, cov)
494 x = self._process_quantiles(x, dim)
--> 495 psd = _PSD(cov, allow_singular=allow_singular)
496 out = self._logpdf(x, mean, psd.U, psd.log_pdet, psd.rank)
497 return _squeeze_output(out)
c:\users\jbustos\appdata\local\programs\python\python37\lib\site-packages\scipy\stats\_multivariate.py in __init__(self, M, cond, rcond, lower, check_finite, allow_singular)
161 d = s[s > eps]
162 if len(d) < len(s) and not allow_singular:
--> 163 raise np.linalg.LinAlgError('singular matrix')
164 s_pinv = _pinv_1d(s, eps)
165 U = np.multiply(u, np.sqrt(s_pinv))
LinAlgError: singular matrix
有人可以帮助我了解我怎么弄错了吗?谢谢。