我使用Matplolib和Pandas数据框创建了一个散点图,现在我想为其添加图例。这是我的代码:
colors = ['red' if x >= 150 and x < 200 else
'green' if x >= 200 and x < 400 else
'purple' if x >= 400 and x < 600 else
'yellow' if x >= 600 else 'teal' for x in myData.R]
ax1.scatter(myData.X, myData.Y, s=20, c=colors, marker='_', label='Test')
ax1.legend(loc='upper left', frameon=False)
这里发生的是,根据myData.R
的值,散点图中点的颜色将改变。因此,由于颜色是“动态”的,因此在创建图例时遇到很多麻烦。实际的代码只会创建带有单个标签“ Test”的图例,并且附近没有任何颜色。
以下是数据示例:
X Y R
0 1 945 1236.334519
0 1 950 212.809352
0 1 950 290.663847
0 1 961 158.156856
我尝试了this,但我不明白的是:
如何动态为图例设置标签?例如,我的代码显示为'red' if x >= 150
,因此在图例上应该有一个红色正方形,其附近是> 150 。但是由于我没有手动添加任何标签,因此在理解这一点时遇到了麻烦。
尝试以下操作后,我只得到了带有单个标签“ Classes”的图例:
`legend1 = ax1.legend(* scatter.legend_elements(), loc =“左下”,title =“ Classs”)
ax1.add_artist(legend1)`
任何建议都值得赞赏!
答案 0 :(得分:2)
可以加速的一部分代码是使用普通的Python循环创建字符串列表。 熊猫非常有效地使用了numpy的过滤。 绘制散点图主要取决于点的数量,当一次绘制所有点或将其分为五个部分时,不变。
在循环中使用matplotlib的散点图的一些示例代码:
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
N = 500
myData = pd.DataFrame({'X': np.round(np.random.uniform(-1000, 1000, N), -2),
'Y': np.random.uniform(-800, 800, N)})
myData['R'] = np.sqrt(myData.X ** 2 + myData.Y ** 2)
fig, ax1 = plt.subplots()
bounds = [150, 200, 400, 600]
colors = ['teal', 'red', 'green', 'purple', 'gold']
for b0, b1, col in zip([None]+bounds, bounds+[None], colors):
if b0 is None:
filter = (myData.R < b1)
label = f'$ R < {b1} $'
elif b1 is None:
filter = (myData.R >= b0)
label = f'${b0} ≤ R $'
else:
filter = (myData.R >= b0) & (myData.R < b1)
label = f'${b0} ≤ R < {b1}$'
ax1.scatter(myData.X[filter], myData.Y[filter], s=20, c=col, marker='_', label=label)
ax1.legend()
plt.show()
或者,熊猫的cut
可用于创建类别,而Seaborn的功能(如其hue
参数则可进行着色并自动创建图例。
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
N = 500
myData = pd.DataFrame({'X': np.round( np.random.uniform(-1000, 1000, N),-2), 'Y': np.random.uniform(-800, 800, N)})
myData['R'] = np.sqrt(myData.X ** 2 + myData.Y ** 2)
fig, ax1 = plt.subplots()
bounds = [150, 200, 400, 600]
colors = ['teal', 'red', 'green', 'purple', 'gold']
hues = pd.cut(myData.R, [0]+bounds+[2000], right=False)
sns.scatterplot(myData.X, myData.Y, hue=hues, hue_order=hues.cat.categories, palette=colors, s=20, marker='_', ax=ax1)
plt.show()