如何创建基于多种颜色的图例?

时间:2019-06-22 03:34:30

标签: python matplotlib

我目前有系外行星日射强度与密度的关系图,其中不同的颜色导致了不同的轨道周期。我已经弄清楚了颜色情况,我只是对如何设置图例感到困惑。这就是我所拥有的。

plt.figure(figsize = (9,7))
plt.title('Insolation vs Density', fontsize = 24, 
fontweight='bold')
plt.xlabel('Density [g/cm$^3$]', fontsize = 16)
plt.ylabel('Insolation [Earth Flux]', fontsize=16)
plt.xscale('log')
plt.yscale('log')
x = data['Density [g/cm**3]']
y = data['Insolation [Earth Flux]']
z = data['Orbital Period']

def pltcolor(lst):
    cols=[]
    for i in data['Orbital Period']:
        if i <= 3:
            cols.append('mediumturquoise'),
        elif i >= 20 :
            cols.append('blue'),
        else:
            cols.append('crimson')
    return cols
cols=pltcolor(z)

plt.scatter(x=x,y=y,c=cols)
plt.scatter(circum_data['Density [g/cm**3]'],circum_data['Insolation [Earth Flux]'], color = 'fuchsia', label = 
Circumbinary Planets')
plt.legend();

2 个答案:

答案 0 :(得分:0)

您可以将图例标签(称为“标签”)放置在与“ x”向量长度相同的向量中,然后执行以下操作:

plt.legend(labels)

您也可以以不同的方式进行操作(我想这就是您要尝试的操作):您可以将每个组绘制成一个循环并设置标签。就像这样:

for i,group in enumerate(groups):
    plt.scatter(x[group],y[group],label=group_names[i])
plt.legend()

答案 1 :(得分:0)

根据我的理解,您需要为每个组呼叫plt.scatter。供参考,请看此question。现在,您要弄清楚每种数据点的颜色应该是什么,然后在cols中为其分配颜色。然后,您只需调用plt.scatter,它就会绘制所有点并相应地分配颜色。但是,matplotlib仍然认为所有这些观点都来自同一组。因此,当您致电plt.legend()时,它只会提供一个标签。

我已尝试使用您的代码为您解决。这有点棘手,因为您从示例中删除了数据(可以理解)。我假设您的数据是一个列表,因此创建了一些假数据来测试我的方法。

因此,我的方法如下:浏览数据,如果z数据位于特定范围内,则将其分配给新数组。处理完该组(z范围)的所有数据后,将其绘制出来。然后对每个组重复此操作。我在下面附上了我的想法示例。可能有更清洁的方法可以做到这一点。但是,总体方法是相同的。尝试分别绘制每个组。

import matplotlib.pyplot as plt
import math

# Fake data I created
data = {}
data['Density [g/cm**3]'] = [10,15, 31, 24,55]
data['Insolation [Earth Flux]'] = [10,15,8,4,55]
data['Orbital Period'] = [10,15,3,2,55]

circum_data = {}
circum_data['Density [g/cm**3]'] = [10,15,7,5,55]
circum_data['Insolation [Earth Flux]'] = [10,15,4,3,55]

# ----- Your code------
plt.figure(figsize = (9,7))
plt.title('Insolation vs Density', fontsize = 24, fontweight='bold')
plt.xlabel('Density [g/cm$^3$]', fontsize = 16)
plt.ylabel('Insolation [Earth Flux]', fontsize=16)
plt.xscale('log')
plt.yscale('log')
x = data['Density [g/cm**3]']
y = data['Insolation [Earth Flux]']
z = data['Orbital Period']
# -----------------

# Created the ranges you want
distances_max = [3, 20, math.inf]
distances_min = [-1*math.inf, 3, 20]
# Select you colors
colors = ['mediumturquoise', 'blue', 'crimson']
# The legend names you want
names = ['name1', 'name2', 'name3']

# For each of the ranges
for i in range(len(names)):
    # Create a new group of data
    col = []
    data_x = []
    data_y = []
    # Go through your data and put it in the group if it is inside the range
    for (xi, yi, zi) in zip(x, y, z): 
        if distances_min[i] < zi <= distances_max[i]:
            col.append(colors[i])
            data_x.append(xi) 
            data_y.append(yi) 

    # Plot the group of data
    plt.scatter(x=data_x,y=data_y,c=colors[i], label=names[i])   


# plt.scatter(circum_data['Density [g/cm**3]'],
#             circum_data['Insolation [Earth Flux]'],
#             color = 'fuchsia',
#             label ='Circumbinary Planets')
plt.legend()
plt.show()

运行此代码会产生以下输出,其中names列表中定义了name1,name2,name3。

Example graph

我希望这会有所帮助。祝你好运!