将图例添加到动态颜色的matplotlib散点图中

时间:2020-06-24 09:13:12

标签: python python-3.x matplotlib

我使用Matplolib和Pandas数据框创建了一个散点图,现在我想为其添加图例。这是我的代码:

colors = ['red' if x >= 150 and x < 200 else 
          'green' if x >= 200 and x < 400 else
          'purple' if x >= 400 and x < 600 else
          'yellow' if x >= 600 else 'teal' for x in myData.R]


ax1.scatter(myData.X, myData.Y, s=20, c=colors, marker='_', label='Test')
ax1.legend(loc='upper left', frameon=False)

这里发生的是,根据myData.R的值,散点图中点的颜色将改变。因此,由于颜色是“动态”的,因此在创建图例时遇到很多麻烦。实际的代码只会创建带有单个标签“ Test”的图例,并且附近没有任何颜色。

以下是数据示例:

       X  Y    R
0      1  945  1236.334519
0      1  950   212.809352
0      1  950   290.663847
0      1  961   158.156856

我尝试了this,但我不明白的是:

  1. 如何动态为图例设置标签?例如,我的代码显示为'red' if x >= 150,因此在图例上应该有一个红色正方形,其附近是> 150 。但是由于我没有手动添加任何标签,因此在理解这一点时遇到了麻烦。

  2. 尝试以下操作后,我只得到了带有单个标签“ Classes”的图例:

`legend1 = ax1.legend(* scatter.legend_elements(), loc =“左下”,title =“ Classs”)

ax1.add_artist(legend1)`

任何建议都值得赞赏!

1 个答案:

答案 0 :(得分:2)

可以加速的一部分代码是使用普通的Python循环创建字符串列表。 熊猫非常有效地使用了numpy的过滤。 绘制散点图主要取决于点的数量,当一次绘制所有点或将其分为五个部分时,不变。

在循环中使用matplotlib的散点图的一些示例代码:

from matplotlib import pyplot as plt
import numpy as np
import pandas as pd

N = 500
myData = pd.DataFrame({'X': np.round(np.random.uniform(-1000, 1000, N), -2), 
                       'Y': np.random.uniform(-800, 800, N)})
myData['R'] = np.sqrt(myData.X ** 2 + myData.Y ** 2)

fig, ax1 = plt.subplots()

bounds = [150, 200, 400, 600]
colors = ['teal', 'red', 'green', 'purple', 'gold']
for b0, b1, col in zip([None]+bounds, bounds+[None], colors):
    if b0 is None:
        filter = (myData.R < b1)
        label = f'$ R < {b1} $'
    elif b1 is None:
        filter = (myData.R >= b0)
        label = f'${b0} ≤ R $'
    else:
        filter = (myData.R >= b0) & (myData.R < b1)
        label = f'${b0} ≤ R < {b1}$'
    ax1.scatter(myData.X[filter], myData.Y[filter], s=20, c=col, marker='_', label=label)
ax1.legend()
plt.show()

matplotlib's scatterplot

或者,熊猫的cut可用于创建类别,而Seaborn的功能(如其hue参数则可进行着色并自动创建图例。

from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

N = 500
myData = pd.DataFrame({'X': np.round( np.random.uniform(-1000, 1000, N),-2), 'Y': np.random.uniform(-800, 800, N)})
myData['R'] = np.sqrt(myData.X ** 2 + myData.Y ** 2)

fig, ax1 = plt.subplots()

bounds = [150, 200, 400, 600]
colors = ['teal', 'red', 'green', 'purple', 'gold']

hues = pd.cut(myData.R, [0]+bounds+[2000], right=False)
sns.scatterplot(myData.X, myData.Y, hue=hues, hue_order=hues.cat.categories, palette=colors, s=20, marker='_', ax=ax1)
plt.show()

seaborn scatterplot