我正在尝试使用matplotlib和seaborn来创建散点图。如果整个绘图只有一种颜色如下所示,它可以正常工作:
sns.regplot(x = pair[0], y = pair[1], data = d, fit_reg = False, ax = ax, x_jitter = True, scatter_kws = {'linewidths':0, 's':2, 'color':'r'})
但是,如果我需要每个数据点的颜色取决于col
中的值,如:
col = pandas_df.prediction.map({0: [1,0,0], 1:[0,1,0]})
sns.regplot(x = pair[0], y = pair[1], data = d, fit_reg = False, ax = ax, x_jitter = True, scatter_kws = {'linewidths':0, 's':2, 'cmap':"RGB", 'color':col})
其中pandas_df是一个pandas数据帧,因此col
是一系列RGB点,如:
[1,0,0]
[0,1,0]
[1,0,0]
[0,1,0]
:
:
然后我得到了错误:
IndexErrorTraceback (most recent call last)
<ipython-input-12-e17a2dbdd639> in <module>()
15 #print dtype(col)
16 d.plot.scatter(*pair, ax=ax, c=col, linewidths=0, s=2, alpha = 0.7)
---> 17 sns.regplot(x = pair[0], y = pair[1], data = d, fit_reg = False, ax = ax, x_jitter = True, scatter_kws = {'linewidths':0, 's':2, 'cmap':"RGB", 'color':col})
18
19 fig.tight_layout()
/usr/local/lib/python2.7/dist-packages/seaborn/linearmodels.pyc in regplot(x, y, data, x_estimator, x_bins, x_ci, scatter, fit_reg, ci, n_boot, units, order, logistic, lowess, robust, logx, x_partial, y_partial, truncate, dropna, x_jitter, y_jitter, label, color, marker, scatter_kws, line_kws, ax)
777 scatter_kws["marker"] = marker
778 line_kws = {} if line_kws is None else copy.copy(line_kws)
--> 779 plotter.plot(ax, scatter_kws, line_kws)
780 return ax
781
/usr/local/lib/python2.7/dist-packages/seaborn/linearmodels.pyc in plot(self, ax, scatter_kws, line_kws)
328 # Draw the constituent plots
329 if self.scatter:
--> 330 self.scatterplot(ax, scatter_kws)
331 if self.fit_reg:
332 self.lineplot(ax, line_kws)
/usr/local/lib/python2.7/dist-packages/seaborn/linearmodels.pyc in scatterplot(self, ax, kws)
353 kws.setdefault("linewidths", lw)
354
--> 355 if not hasattr(kws['color'], 'shape') or kws['color'].shape[1] < 4:
356 kws.setdefault("alpha", .8)
357
IndexError: tuple index out of range
在这种情况下,在分配颜色和cmap时我做错了什么?谢谢!
答案 0 :(得分:0)
我自己就遇到过这个问题,代码在一年前就有效了。 (我可能已经从Python 2切换到Python 3,这可能解释了错误。)
我挖了一下代码,正如你所说,错误发生在
--> 355 if not hasattr(kws['color'], 'shape') or kws['color'].shape[1] < 4:
356 kws.setdefault("alpha", .8)
357
IndexError: tuple index out of range
如果您查看此处发生的情况,无论您传入'color'
关键字(在您的情况下为'color':col
),都需要 具有以下特点:
shape
属性shape
属性,则该属性必须至少有2个维度。嗯,有问题: pandas Series
或 numpy ndarray
(或其他几个数据结构,我猜),具有shape
属性,只能有一个维度。
例如,当我遇到问题时,我有类似以下内容:
col.shape
(2506,)
这意味着我的col
变量(在我的情况下, pandas Series
对象),两者都有shape
和形状只有一个维度。
我不明白如何解决这个问题。我试图强迫我的 pandas Series
加入list
,但这并没有解决问题。我试图传递一个2D pandas DataFrame
,其中每列相同,但没有修复它。
在查看source code时,对我来说,如何解决问题并不明显。似乎右修复,可能是在第355行添加另一个看起来像这样的检查:
355 if not hasattr(kws['color'], 'shape') or len(kws['color'].shape) < 2 or kws['color'].shape[1] < 4:
但我没有精力(或时间)来解决分配源和提交修复的麻烦。 :(