使用sklearn.mixture.GMM

时间:2016-01-07 20:42:09

标签: python machine-learning scikit-learn unsupervised-learning mixture-model

我使用sklearn.mixture.GMM来拟合一些数据,但对于数据集中的某一个实例,无法从拟合后的GMM中进行采样。

在超过1000个数据实例中它工作正常,但在下面的情况下(data_not_working)我在运行以下代码时遇到错误:

from sklearn import mixture
import numpy as np 

data_not_working = np.array([[-13.3669, -0.152287, -0.926697, 0.0967975, 0.375109, 0.22213, 0.364592, 0.283643, 0.614218, -0.117485, 0.221134, 0.104302], [-7.32323, -0.515594, -0.864193, 0.102628, 0.32041, 0.0606005, 0.197593, 0.025868, 0.249107, -0.0754152, 0.0994283, 0.0511292], [-5.70166, -0.408034, -1.22175, 0.220845, 0.2968, 0.0308518, 0.013137, -0.672265, -0.180614, -0.231932, -0.141483, 0.318216], [-3.84773, -0.13171, -1.37403, 0.242801, 0.399666, -0.150793, -0.342479, -0.689551, -0.246872, 0.00635363, 0.148948, 0.221603], [-3.12773, 0.172297, -1.38291, 0.00240961, 0.475504, 0.18957, -0.593592, -0.378285, -0.195662, -0.10973, 0.369654, 0.143974], [-2.43561, 0.0644245, -0.95012, 0.289466, 0.292279, -0.0631116, -0.546317, -0.138747, -0.104671, -0.0917557, 0.101156, -0.0469524], [-2.76789, -0.0416676, -1.18993, 0.392875, 0.136845, -0.263689, -0.402386, 0.206513, 0.335653, 0.0999453, 0.0125673, 0.226993], [-2.57943, -0.102039, -1.46225, 0.550504, 0.103789, 0.0240493, -0.116903, 0.25877, 0.189019, -0.107692, -0.134221, 0.333413], [-2.44367, 0.119016, -0.61038, 0.896835, 0.0487419, 0.281915, -0.0475086, -0.145234, 0.126528, -0.109666, 0.0714544, 0.102345], [-2.73143, 0.317259, -0.546473, 0.842293, -0.228764, 0.0580869, -0.128803, -0.523804, 0.0935071, -0.0131786, -0.0838011, -0.299564], [-2.86395, 0.282303, -1.00826, 0.65241, -0.317471, -0.0948204, 0.186242, -0.214155, 0.0747489, -0.163622, -0.00290485, -0.0116438], [-2.96273, 0.210327, -0.76213, 0.743427, -0.435498, -0.249532, 0.249474, -0.160216, -0.12336, -0.240312, -0.270668, -0.133469], [-3.35801, 0.362276, -0.507548, 0.301616, -0.583986, -0.424966, 0.0257714, -0.11669, 0.201161, 0.0104573, -0.267932, 0.164152], [-3.52099, 0.489393, -0.45938, 0.0439511, -0.250481, -0.490404, -0.0479253, 0.13449, -0.229827, -0.116102, -0.0683664, -0.0311946], [-3.01492, -0.0464895, -0.166774, -0.147464, -0.258049, -0.401865, 0.0168582, 0.277897, -0.0941365, -0.375444, -0.0174562, 0.0673491], [-3.30715, 0.26851, -0.803025, 
-0.0088587, -0.258561, -0.369787, 0.0882617, 0.223542, 0.0424378, -0.179769, 0.138257, 0.0615963], [-4.87222, 0.403703, -1.07541, 0.0120966, 0.00684427, -0.111497, 0.164573, 0.410325, -0.364741, 0.0662429, 0.0136844, 0.384867], [-5.87392, -0.310827, -1.04405, 0.176996, -0.131957, 0.2619, 0.0554216, 0.140458, -0.17792, 0.0856086, -0.375274, -0.0801583], [-7.16114, 0.866077, -1.83373, 0.625741, 0.0481332, 0.0240574, -0.135544, 0.294257, 0.0575935, -0.146078, -0.355156, 0.198461]])

def gmmSample(data):
    """Fit a Gaussian mixture to ``data`` and draw samples from it.

    The model is built with ``n_components=3``, ``covariance_type='full'``
    and ``n_iter=100``. After fitting, 100000 samples are drawn; they are
    discarded, so the function always returns ``None``. It is only useful
    for checking whether fitting and sampling complete without raising.

    Args:
        data: Array-like of observations; it is converted with ``np.array``
            before being passed to ``fit``.
    """
    observations = np.array(data)
    model = mixture.GMM(n_components=3, covariance_type='full', n_iter=100)
    model.fit(observations)
    # The result is intentionally not kept: only the call itself matters here.
    model.sample(100000)

# Reproducer: per the traceback reported below, the failure is raised from
# inside the GMM's sample() call (scipy's linalg.eigh), not from fit().
gmmSample(data_not_working)

这会产生以下运行时错误:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Library/Python/2.7/site-packages/sklearn/mixture/gmm.py", line 411, in sample
    num_comp_in_X, random_state=random_state).T
  File "/Library/Python/2.7/site-packages/sklearn/mixture/gmm.py", line 102, in sample_gaussian
    s, U = linalg.eigh(covar)
  File "/System/Library/Frameworks/Python.framework/Versions/2.7/Extras/lib/python/scipy/linalg/decomp.py", line 387, in eigh
    raise LinAlgError("unrecoverable internal error.")
numpy.linalg.linalg.LinAlgError: unrecoverable internal error.

所以问题出在从GMM中采样,而不是拟合它。下面是一个数据实例的示例,对它运行上面的代码完全正常(我使用的其他1000多个实例也都是如此)。所有实例都具有相同的形状:

data_working = np.array([[-13.8942, 0.329383, -0.467724, -0.0533347, 0.135847, 0.063669, 0.205088, 0.0188045, 0.200259, -0.153357, 0.0282053, 0.19137], [-10.0263, -0.232325, -1.23603, -0.373344, -0.270465, 0.223835, 0.245468, -0.14771, -0.21643, 0.0690714, 0.00436133, -0.0100653], [-7.2949, -1.02805, -0.360764, -0.211618, -0.0396331, 0.138607, 0.0274424, -0.0949814, -0.0290368, -0.195617, -0.064841, -0.0334741], [-5.27361, -1.45856, -0.0538218, 0.325073, -0.0113113, -0.182038, 0.0113554, 0.0380641, -0.155189, -0.000775465, -0.0834289, -0.00448654], [-3.4687, -1.80423, 0.181359, 0.216309, -0.0175896, -0.14976, 0.011689, -0.123908, -0.234207, 0.0114323, -0.157273, 0.153515], [-5.46375, -1.50817, -0.26668, 0.114913, 0.041553, 0.232375, 0.193539, -0.022985, -0.123261, 0.0131678, -0.225528, 0.0131385], [-8.96966, -0.926118, -1.14693, -0.0732326, -0.069377, 0.202194, -0.0373959, 0.155714, -0.0575818, 0.153754, 0.0827817, -0.0899819], [-5.4489, -1.46598, -0.904309, -0.180178, -0.0387, 0.284963, -0.0209437, 0.161178, -0.334906, 0.0925891, 0.0626761, -0.20815], [-6.67765, -0.909459, -0.893041, -0.528669, -0.287356, -0.317459, 0.0218326, 0.212814, -0.0544577, 0.0569478, -0.21171, -0.166358], [-5.83495, -1.40242, -1.08698, -0.295603, -0.44182, 0.0875251, -0.307424, 0.0605037, 0.142951, 0.0753836, -0.0953188, 0.00819761], [-5.92017, -1.05822, -0.898107, -0.0233588, -0.318233, -0.266055, -0.458731, 0.132217, -0.107108, -0.154634, -0.00669574, 0.142476], [-6.2026, -1.71479, -0.465533, -0.26163, 0.303861, -0.00872642, 0.155504, 0.614625, -0.207519, -0.212606, -0.0592188, 0.0887861], [-10.7305, -1.13431, -0.979158, 0.219761, -0.342731, -0.175846, 0.0111934, 0.226708, -0.0161784, -0.248745, 0.0470983, -0.0252792], [-8.0586, -1.45944, -1.18256, 0.0650664, 0.259971, -0.285369, -0.202342, 0.0675689, -0.238931, -0.0665339, 0.0854533, 0.0714763], [-5.61462, -1.77467, -1.17853, -0.402395, 0.0316058, -0.358417, -0.212316, 0.215444, 0.0111266, -0.17753, 0.106201, 0.102555], [-7.32914, 
-1.46897, -1.03672, 0.209392, -0.032743, -0.0519038, -0.30758, -0.377465, -0.329729, 0.0569532, -0.0359641, 0.182907], [-6.88854, -1.81873, -0.421743, -0.312312, -0.218102, 0.10227, -0.200002, -0.161226, -0.319451, 0.21934, -0.203555, -0.0566904], [-5.54895, -1.97478, -0.552426, -0.232346, -0.192567, -0.213922, -0.118116, 0.0830695, 0.0688067, 0.163558, 0.0393377, 0.269313], [-6.06666, -1.81661, -0.410524, -0.135279, -0.0956775, -0.269271, -0.164703, -0.0854252, -0.113826, 0.003071, 0.0617395, 0.247204]])

如果我将从GMM中取出的样本数量减少到10,那么它有时会起作用,但不是每次都有效!

进一步查看数据后发现,data_not_working的成分(component)数量可能 <= 2。将成分数降为2时,代码运行没有错误。因此,尝试用3个成分对其建模可能就是问题的起因。但是,我仍然不明白导致此错误的具体原因,以及库中是否存在bug。

我现在也试过在不同的系统上运行相同的代码。它在某些机器上可以运行,在另一些机器上则不行。这似乎不受python或library版本的影响(2台机器运行相同的磁盘映像,python、scipy、numpy和sklearn版本也相同;一台正常,另一台不行)......非常奇怪。

是我遗漏了什么明显的东西,还是库本身存在问题?谢谢

0 个答案:

没有答案