Pyplot散点图图例不适用于较小的样本量

时间:2017-04-19 18:01:47

标签: python matplotlib plot

我正在使用下面的代码在pyplot中生成一个散点图,我想让9个类中的每一个以不同的颜色绘制。每个班级都有多个要点。

我无法弄清楚为什么图例不适用于较小的样本量。

def plot_scatter_test(x, y, c, title):
    data = pd.DataFrame({'x': x, 'y': y, 'c': c})
    classes = len(np.unique(c))
    colors = cm.rainbow(np.linspace(0, 1, classes))

    ax = plt.subplot(111)
    for s in range(0,classes):
        ss = data[data['c']==s]
        plt.scatter(x=ss['x'], y=ss['y'],c=colors[s], label=s)

    ax.legend(loc='lower left',scatterpoints=1, ncol=3, fontsize=8, bbox_to_anchor=(0, -.4), title='Legend')
    plt.show()

我的数据看起来像这样

Data

当我通过调用

来绘制它时
plot_scatter_test(test['x'], test['y'],test['group'])

我在图表中获得了不同的颜色,但图例是单色

Chart 1

因此,为了确保我的数据正常,我使用相同类型的数据创建了一个随机数据帧。现在我得到了不同的颜色,但有些东西仍然是错误的,因为它们不是连续的。

test2 = pd.DataFrame({
    'y': np.random.uniform(0,1400,36),
    'x': np.random.uniform(-250,-220,36),
    'group': np.random.randint(0,9,36)
})
plot_scatter_test(test2['x'], test2['y'],test2['group'])

colors2

最后,我创建了一个更大的360数据点图,一切看起来都像我期望的那样。我做错了什么?

test3 = pd.DataFrame({
    'y': np.random.uniform(0,1400,360),
    'x': np.random.uniform(-250,-220,360),
    'group': np.random.randint(0,9,360)
})

plot_scatter_test(test3['x'], test3['y'],test3['group'])

chart3

2 个答案:

答案 0 :(得分:1)

您需要确保不要将类本身与用于编制索引的编号混淆。

为了更好地观察我的意思,请在函数中使用以下数据集:

np.random.seed(22)
X,Y= np.meshgrid(np.arange(3,7), np.arange(4,8))
test2 = pd.DataFrame({
    'y': Y.flatten(),
    'x': X.flatten(),
    'group': np.random.randint(0,9,len(X.flatten()))
})
plot_scatter_test(test2['x'], test2['y'],test2['group'])

导致以下图表,其中缺少点。

enter image description here

因此,明确区分索引和类,例如如下

import numpy as np; np.random.seed(22)
import matplotlib.pyplot as plt
import pandas as pd

def plot_scatter_test(x, y, c, title="title"):
    data = pd.DataFrame({'x': x, 'y': y, 'c': c})
    classes = np.unique(c)
    print classes
    colors = plt.cm.rainbow(np.linspace(0, 1, len(classes)))
    print colors
    ax = plt.subplot(111)
    for i, clas in enumerate(classes):
        ss = data[data['c']==clas]
        plt.scatter(ss["x"],ss["y"],c=[colors[i]]*len(ss), label=clas)

    ax.legend(loc='lower left',scatterpoints=1, ncol=3, fontsize=8,  title='Legend')
    plt.show()

X,Y= np.meshgrid(np.arange(3,7), np.arange(4,8))
test2 = pd.DataFrame({
    'y': Y.flatten(),
    'x': X.flatten(),
    'group': np.random.randint(0,9,len(X.flatten()))
})
plot_scatter_test(test2['x'], test2['y'],test2['group'])

enter image description here

除此之外,确实有必要不将颜色4元组直接提供给c,因为这会被解释为四种单色。

答案 1 :(得分:-1)

在盯着这一段时间之后,我觉得很傻。错误在于传递的颜色。我将单一颜色传递给.scatter函数。但是,由于有多个点,您需要传递相同数量的颜色。因此

plt.scatter(x=ss['x'], y=ss['y'],c=colors[s], label=s)

可以是

之类的东西
plt.scatter(x=ss['x'], y=ss['y'],c=[colors[s]]*len(ss), label=s)