DBSCAN plot - ValueError thrown by the color value passed to plt.plot()

Time: 2017-10-06 19:13:48

Tags: numpy scikit-learn cluster-analysis outliers dbscan

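I am using DBSCAN to cluster a dataset, and plotting the result raises a ValueError. I think the error occurs because the color argument passed to markerfacecolor in plt.plot() is not a single value. Please tell me if I'm wrong. My features are latitude, longitude, speed_mph, speedlimit_mph, vehicle_id and driver_id.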

Here is my clustering code:

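import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.cluster import DBSCAN

dbsc = DBSCAN(eps = .5, min_samples = 5).fit(df_cont)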

labels = dbsc.labels_
print(labels)

num_clusters = len(set(labels))
clusters = pd.Series([df_cont[labels == n] for n in range(num_clusters)])
print('Number of clusters: {}'.format(num_clusters))
# Number of clusters: 5687

core_samples = np.zeros_like(labels, dtype = bool)
core_samples[dbsc.core_sample_indices_] = True

unique_labels = np.unique(labels)

colors = plt.cm.Spectral(np.linspace(0,1, len(unique_labels)))

for (label, color) in zip(unique_labels, colors):
    class_member_mask = (labels == label)
    xy = df_cont[class_member_mask & core_samples]
    print("color:",color)
    # color: [ 0.61960784  0.00392157  0.25882353  1.        ]

    plt.plot(xy.values[:,0],xy.values[:,1], marker='o', markerfacecolor = color, markersize = 10)

    xy2 = df_cont[class_member_mask & ~core_samples]
    plt.plot(xy2.values[:,0],xy2.values[:,1], 'o', markerfacecolor = color, markersize = 5)

plt.title("DBSCAN Driver - Speed MPH")
plt.xlabel("driver")
plt.ylabel("Speed")
plt.show()
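Separately from the plotting error: DBSCAN marks noise samples with the label -1, so num_clusters = len(set(labels)) counts noise as one extra "cluster" whenever any noise exists, and range(num_clusters) never visits -1 while including one label id that no sample has. A minimal sketch with made-up labels (not the asker's data) showing the difference:

```python
import numpy as np

# Made-up labels in the DBSCAN.labels_ format: clusters 0..k-1, noise = -1.
labels = np.array([0, 0, 1, 1, -1, 2, -1])

label_set = {int(v) for v in labels}
num_labels = len(label_set)  # counts -1 as if it were a cluster

# What range(num_labels) gets wrong, as in the question's list comprehension:
skipped = label_set - set(range(num_labels))   # noise label never visited
phantom = set(range(num_labels)) - label_set   # label id no sample has

# Cluster count that excludes noise, plus the number of noise samples.
num_clusters = num_labels - (1 if -1 in label_set else 0)
num_noise = int(np.sum(labels == -1))

# Group sample indices by the labels that actually occur (noise included).
groups = {int(label): np.flatnonzero(labels == label) for label in np.unique(labels)}

print(num_clusters, num_noise, sorted(groups))  # 3 2 [-1, 0, 1, 2]
```

Iterating over np.unique(labels), as the plotting loop already does, avoids both problems.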

Here is the error message that is thrown:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-105-0192647e6baf> in <module>()
      3     xy = df_cont[class_member_mask & core_samples]
      4     print("color:",color)
----> 5     plt.plot(xy.values[:,0],xy.values[:,1], marker='o', markerfacecolor = color, markersize = 10)
      6 
      7     xy2 = df_cont[class_member_mask & ~core_samples]

/home/radiance/anaconda3/lib/python3.6/site-packages/matplotlib/pyplot.py in plot(*args, **kwargs)
   3315                       mplDeprecation)
   3316     try:
-> 3317         ret = ax.plot(*args, **kwargs)
   3318     finally:
   3319         ax._hold = washold

/home/radiance/anaconda3/lib/python3.6/site-packages/matplotlib/__init__.py in inner(ax, *args, **kwargs)
   1896                     warnings.warn(msg % (label_namer, func.__name__),
   1897                                   RuntimeWarning, stacklevel=2)
-> 1898             return func(ax, *args, **kwargs)
   1899         pre_doc = inner.__doc__
   1900         if pre_doc is None:

/home/radiance/anaconda3/lib/python3.6/site-packages/matplotlib/axes/_axes.py in plot(self, *args, **kwargs)
   1404         kwargs = cbook.normalize_kwargs(kwargs, _alias_map)
   1405 
-> 1406         for line in self._get_lines(*args, **kwargs):
   1407             self.add_line(line)
   1408             lines.append(line)

/home/radiance/anaconda3/lib/python3.6/site-packages/matplotlib/axes/_base.py in _grab_next_args(self, *args, **kwargs)
    405                 return
    406             if len(remaining) <= 3:
--> 407                 for seg in self._plot_args(remaining, kwargs):
    408                     yield seg
    409                 return

/home/radiance/anaconda3/lib/python3.6/site-packages/matplotlib/axes/_base.py in _plot_args(self, tup, kwargs)
    393         ncx, ncy = x.shape[1], y.shape[1]
    394         for j in xrange(max(ncx, ncy)):
--> 395             seg = func(x[:, j % ncx], y[:, j % ncy], kw, kwargs)
    396             ret.append(seg)
    397         return ret

/home/radiance/anaconda3/lib/python3.6/site-packages/matplotlib/axes/_base.py in _makeline(self, x, y, kw, kwargs)
    300         default_dict = self._getdefaults(None, kw)
    301         self._setdefaults(default_dict, kw)
--> 302         seg = mlines.Line2D(x, y, **kw)
    303         return seg
    304 

/home/radiance/anaconda3/lib/python3.6/site-packages/matplotlib/lines.py in __init__(self, xdata, ydata, linewidth, linestyle, color, marker, markersize, markeredgewidth, markeredgecolor, markerfacecolor, markerfacecoloralt, fillstyle, antialiased, dash_capstyle, solid_capstyle, dash_joinstyle, solid_joinstyle, pickradius, drawstyle, markevery, **kwargs)
    418         self._markerfacecoloralt = None
    419 
--> 420         self.set_markerfacecolor(markerfacecolor)
    421         self.set_markerfacecoloralt(markerfacecoloralt)
    422         self.set_markeredgecolor(markeredgecolor)

/home/radiance/anaconda3/lib/python3.6/site-packages/matplotlib/lines.py in set_markerfacecolor(self, fc)
   1204         if fc is None:
   1205             fc = 'auto'
-> 1206         if self._markerfacecolor != fc:
   1207             self.stale = True
   1208         self._markerfacecolor = fc

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
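The last frames of the traceback show the mechanism: Line2D.__init__ first sets self._markerfacecolor = None, then set_markerfacecolor evaluates `if self._markerfacecolor != fc:`. When fc is the NumPy array returned by plt.cm.Spectral(...), that comparison is element-wise, so `if` receives a four-element boolean array. A minimal NumPy-only sketch of the same comparison, assuming the RGBA value printed in the question (`stored` stands in for the matplotlib attribute):

```python
import numpy as np

# RGBA color as printed in the question: a 1-D ndarray of 4 floats.
color = np.array([0.61960784, 0.00392157, 0.25882353, 1.0])

# Stand-in for Line2D._markerfacecolor, which __init__ sets to None first.
stored = None

# Comparing None with an ndarray is element-wise: one bool per RGBA channel.
comparison = stored != color
print(comparison)

# Using that array in an `if`, as set_markerfacecolor does, raises ValueError.
caught = None
try:
    if stored != color:
        pass
except ValueError as err:
    caught = err
    print("ValueError:", err)

# A tuple is compared as a single object, so `!=` yields a plain bool.
print(stored != tuple(color))  # True
```

This suggests the hypothesis is close: the value is a single color, but its ndarray type is what trips the comparison, which is why passing tuple(color) is a common workaround.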

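Also, I tried clustering on latitude and longitude with the haversine metric, but my data includes the other features as well, and DBSCAN threw an error saying only two features are allowed. Should I ask this as a separate question?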

dbsc = DBSCAN(eps = .5, min_samples = 5, algorithm='ball_tree', metric='haversine').fit(np.radians(df_cont))

The contents of df_cont are:

{'Day': [1, 1, 1, 1, 1],
 'Month': [6, 6, 6, 6, 6],
 'Year': [2015, 2015, 2015, 2015, 2015],
 'driver_id': [5693, 5693, 916461, 1145487, 1145487],
 'latitude': [34.640141, 34.64373, 34.551254, 35.613663, 35.614525],
 'longitude': [-77.938721,
  -77.9394,
  -78.78463,
  -78.470596,
  -78.47466999999999],
 'speed_mph': [64, 64, 1, 62, 61],
 'speedlimit_mph': [70, 70, 55, 70, 70],
 'vehicle_id': [1208979, 1208979, 1262441, 1280223, 1280223]}
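On the second question: scikit-learn's haversine metric accepts exactly two columns, [latitude, longitude], in radians, so passing all of df_cont fails; the other features cannot go into that metric. eps is then an angle in radians, so eps = .5 corresponds to roughly 3,200 km on the Earth's surface. A minimal sketch using only the five latitude/longitude rows shown above; eps_km = 1.5 and min_samples = 2 are assumptions chosen for illustration, and 6371.0088 km is the mean Earth radius:

```python
import numpy as np
import pandas as pd
from sklearn.cluster import DBSCAN

# The five rows of df_cont from the question, latitude/longitude only.
df_cont = pd.DataFrame({
    "latitude": [34.640141, 34.64373, 34.551254, 35.613663, 35.614525],
    "longitude": [-77.938721, -77.9394, -78.78463, -78.470596, -78.47466999999999],
})

# haversine expects exactly two columns, [latitude, longitude], in radians.
coords = np.radians(df_cont[["latitude", "longitude"]].to_numpy())

# Convert a neighbourhood radius in km to radians via the mean Earth radius.
EARTH_RADIUS_KM = 6371.0088
eps_km = 1.5  # assumed radius, for illustration only

dbsc = DBSCAN(eps=eps_km / EARTH_RADIUS_KM, min_samples=2,
              algorithm="ball_tree", metric="haversine").fit(coords)
labels = dbsc.labels_
print(labels)
```

With these settings the two pairs of nearby points (each well under 1 km apart) form two clusters and the isolated third row is labelled noise (-1).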

1 answer:

Answer 0 (score: 1)

I fixed the error by using a scatter plot instead:

plt.scatter(xy.values[:,0],xy.values[:,1],s=10,c=color,marker='o')
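Two related details, shown in a minimal sketch that assumes a NumPy array X in place of df_cont and uses a hypothetical helper plot_dbscan with made-up toy data. For plt.plot, passing markerfacecolor=tuple(color) keeps the original approach while avoiding the array comparison. For plt.scatter, matplotlib's documentation warns that a single RGB/RGBA sequence passed as c is indistinguishable from an array of values to colormap (for example, when a group has exactly four points) and recommends a 2-D array with one row, which c=[color] provides. Note also that scatter's s is a marker area in points², whereas markersize is a diameter in points.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np


def plot_dbscan(ax, X, labels, core_mask):
    """Hypothetical helper: core samples drawn large, other samples small."""
    unique_labels = np.unique(labels)
    colors = plt.cm.Spectral(np.linspace(0, 1, len(unique_labels)))
    for label, color in zip(unique_labels, colors):
        member = labels == label
        xy = X[member & core_mask]
        # tuple(color) hands Line2D a plain RGBA tuple instead of an ndarray.
        ax.plot(xy[:, 0], xy[:, 1], "o", markerfacecolor=tuple(color),
                markeredgecolor="k", markersize=10)
        xy2 = X[member & ~core_mask]
        # [color] is a one-row 2-D color array, never mistaken for colormap values.
        # s=25 points² roughly matches markersize=5.
        ax.scatter(xy2[:, 0], xy2[:, 1], c=[color], s=25, marker="o")


# Toy data (made up): two clusters plus one noise point.
X = np.array([[0.0, 0.0], [0.1, 0.0], [0.0, 0.1], [5.0, 5.0], [5.1, 5.0], [9.0, 0.0]])
labels = np.array([0, 0, 0, 1, 1, -1])
core_mask = np.array([True, True, False, True, True, False])

fig, ax = plt.subplots()
plot_dbscan(ax, X, labels, core_mask)
```

Each label produces one Line2D (core samples) and one scatter collection (remaining samples), even when one of the two groups is empty.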