如何根据图像上的颜色绘制图例?

时间:2017-07-07 06:01:17

标签: python matplotlib legend

我现在有一幅像图1这样的图像,其中不同的颜色代表不同的含义。我想在图1的底部添加一个图例(如图2所示),应该怎么做?我已经有每种颜色的RGB值。

图1: enter image description here

图2: enter image description here

这是我得到的代码:

# coding=utf-8
import matplotlib
matplotlib.use('Agg')
import h5py
import numpy
from PIL import Image
from PIL import ImageDraw
import matplotlib.pyplot as plt
import matplotlib.ticker as plticker
import sys
# Rows of (classification code, [R, G, B] colour, class name).
_CLASS_ROWS = [
    (127, [100, 100, 100], "NO DATA"),
    (126, [0, 0, 0], "SPACE"),
    (0, [200, 255, 255], "CLEAR"),
    (2, [0, 0, 244], "WATER CLOUD"),
    (3, [32, 165, 225], "ICED CLOUD"),
    (4, [33, 255, 170], "MIXED CLOUD"),
    (5, [255, 0, 0], "CIRRUS CLOUD"),
    (6, [180, 20, 255], "Opaque cloud"),
    (7, [105, 255, 0], "OVERLAP CLOUD"),
    (9, [224, 180, 0], "UNKNOWN CLOUD"),
]

# Maps each classification code to its RGB colour; the class name is not stored.
table = {code: rgb for code, rgb, _name in _CLASS_ROWS}


def main(_, fn, out):
    """Render the 'EVB1' dataset of an HDF5 file as a colour-coded figure.

    Every value of the dataset that is a key of the module-level ``table``
    is painted with the RGB colour stored there; all other values stay
    black.  The image is then drawn on a matplotlib axes with a grid, fixed
    tick labels and a title, and the figure is saved to ``out``.

    Args:
        _: Ignored positional argument.
        fn: Path of the HDF5 file to read.
        out: Output path.  The raw RGB image is written here first and is
            then overwritten by the finished figure.
    """
    # ``Dataset.value`` was removed in h5py 3.0; ``[()]`` reads the whole
    # dataset on old and new versions alike.  The file is only read, so it
    # is opened read-only explicitly.
    with h5py.File(fn, 'r') as f:
        data = f['EVB1'][()]
    rows, cols = data.shape
    ret = numpy.zeros((rows, cols, 3), 'u1')
    # Iterate the table itself so the painted codes cannot drift out of
    # sync with the colour definitions.
    for code, rgb in table.items():
        ret[data == code] = rgb

    Image.fromarray(ret, mode="RGB").save(out)
    image = Image.open(out)
    my_dpi = 100.

    # Set up a figure whose size in pixels matches the image
    fig = plt.figure(
        figsize=(float(image.size[0]) / my_dpi, float(image.size[1]) / my_dpi),
        dpi=my_dpi,
    )
    ax = fig.add_subplot(111)

    # Set the gridding interval: here we use the major tick interval.
    # Each axis gets its own locator; a locator instance must not be shared
    # between axes.
    myInterval = 249.9
    ax.xaxis.set_major_locator(plticker.MultipleLocator(base=myInterval))
    ax.yaxis.set_major_locator(plticker.MultipleLocator(base=myInterval))

    ax.set_xticklabels(['60', '70', '80', '90', '100', '110', '120', '130', '140'])
    ax.set_yticklabels(('70', '60', '50', '40', '30', '20', '10', '0'))

    # Title: file name of ``out`` up to its first '.', cut at 'V0001'
    out1 = out.split('/')[-1].split('.')[0].split('V0001')[0]
    ax.set_title(out1, fontsize=20)

    # Add the grid
    ax.grid(which='major', axis='both', linestyle='-')

    # Add the image
    ax.imshow(image)

    # Save the figure
    fig.savefig(out)



# Script entry point: sys.argv is unpacked straight into main(), so the script
# name fills the ignored first parameter and exactly two further command-line
# arguments (input file, output path) are expected.
if __name__ == '__main__':
    main(*sys.argv)

2 个答案:

答案 0 :(得分:1)

我这里无法正确显示中文字符,但你应该能明白基本思路:

# coding=utf-8
from matplotlib import pyplot as plt
from matplotlib.patches import Rectangle
import numpy as np

# Placeholder image standing in for the real data: sin(X) * cos(Y) sampled on
# a 100 x 100 grid over the unit square.
x, y = np.linspace(0, 1, 100), np.linspace(0, 1, 100)
X, Y = np.meshgrid(x, y)
array = np.multiply(np.sin(X), np.cos(Y))

plt.imshow(array)

# Legend table, one row per class: [class code, [R, G, B] in 0-255, label].
legend_data = [
    [127, [100, 100, 100], "无数据区"],
    [126, [0, 0, 0], "外太空"],
    [0, [200, 255, 255], "晴空"],
    [2, [0, 0, 244], "水云"],
    [3, [32, 165, 225], "过冷水云"],
    [4, [33, 255, 170], "混合云"],
    [5, [255, 0, 0], "厚冰云"],
    [6, [180, 20, 255], "卷云"],
    [7, [105, 255, 0], "多层云"],
    [9, [224, 180, 0], "不确定"],
]
# One solid-colour proxy patch per legend row.  The colour must be a real
# sequence of floats in [0, 1]: a bare generator expression is not a valid
# matplotlib colour, and dividing by 255.0 keeps the channels fractional even
# under Python 2 integer division.
handles = [
    Rectangle((0, 0), 1, 1, color=tuple(v / 255.0 for v in c))
    for k, c, n in legend_data
]
labels = [n for k, c, n in legend_data]

# Passing explicit handles and labels builds the legend from the table alone,
# independently of what is drawn on the axes.
plt.legend(handles, labels)
plt.show()

结果如下:

result of the code above

这幅图只是占位符,因为我没有输入数据。关键的几行是从表中生成矩形 handles 和 labels 的代码,以及最后调用 legend 命令的那一行。

修改

如果您希望图例严格位于图表下方,可以通过为图例单独定义第二个坐标轴(axes)来实现:

from matplotlib import pyplot as plt
from matplotlib.patches import Rectangle
from matplotlib.gridspec import GridSpec
import numpy as np

from matplotlib.font_manager import FontProperties
ChineseFont = FontProperties('SimHei')

# Placeholder image standing in for the real data: sin(X) * cos(Y) sampled on
# a 100 x 100 grid over the unit square.
x, y = np.linspace(0, 1, 100), np.linspace(0, 1, 100)
X, Y = np.meshgrid(x, y)
array = np.multiply(np.sin(X), np.cos(Y))

# Six stacked grid rows: the plot takes all but the last, the legend the last.
fig = plt.figure(figsize=(4, 6))
gs = GridSpec(6, 1)
ax1 = fig.add_subplot(gs[:-1, :])  # for the plot
ax2 = fig.add_subplot(gs[-1, :])  # for the legend

ax1.imshow(array)

# Legend table, one row per class: [class code, [R, G, B] in 0-255, label].
# Every row is spelled out: a literal ``...`` placeholder in the list would be
# an Ellipsis object and break the three-way unpacking of the rows below.
# Labels carry the ``u`` prefix so they are unicode on Python 2 as well.
legend_data = [
    [127, [100, 100, 100], u"无数据区"],
    [126, [0, 0, 0], u"外太空"],
    [0, [200, 255, 255], u"晴空"],
    [2, [0, 0, 244], u"水云"],
    [3, [32, 165, 225], u"过冷水云"],
    [4, [33, 255, 170], u"混合云"],
    [5, [255, 0, 0], u"厚冰云"],
    [6, [180, 20, 255], u"卷云"],
    [7, [105, 255, 0], u"多层云"],
    [9, [224, 180, 0], u"不确定"],
]
# One solid-colour proxy patch per legend row.  Dividing by 255.0 (not 255)
# keeps the channels fractional under Python 2 integer division as well.
handles = [
    Rectangle((0, 0), 1, 1, color=tuple(v / 255.0 for v in c))
    for k, c, n in legend_data
]
labels = [n for k, c, n in legend_data]

# The legend is drawn in its own axes below the plot; that axes' frame and
# ticks are switched off so only the legend remains visible.
ax2.legend(handles, labels, mode='expand', ncol=3, prop=ChineseFont)
ax2.axis('off')
plt.show()

这看起来像这样:

fully working example output

EDIT2

在 this answer 的帮助下,我找到了一种正确显示汉字的方法。该示例现在应该在 Python 2.7 和 Python 3.5 中都能运行——只需在每个标签前加上 u,并且除以 255.0 而不是 255。

答案 1 :(得分:0)

我进行了一些更改以使其更实用。

通常,我们事先并不知道图片上总共有哪些颜色。

所以我使用KMeans对图像颜色进行分类。

from matplotlib import pyplot as plt
from matplotlib.patches import Rectangle
import matplotlib
import numpy as np
import pandas as pd

matplotlib.rcParams['font.sans-serif'] = ['MingLiu']  # 細明體

# Default placeholder image, used when no online image is requested.
y = x = np.linspace(0, 1, 100)
X, Y = np.meshgrid(x, y)
array = np.sin(X) * np.cos(Y)
USING_ONLINE_IMAGE_FLAG = input('load image from online?(y/n)').upper() == 'Y'
if USING_ONLINE_IMAGE_FLAG:  # 'meteorological.png'
    from sklearn import cluster
    from urllib.request import urlopen
    from PIL import Image

    # read image from online
    url_meteorological_img = "https://i.stack.imgur.com/OOIw5.png"
    # Force three channels: the code below unpacks a 3-D shape and reshapes
    # to (-1, 3), which fails for RGBA, palette or greyscale images.
    image = np.asarray(Image.open(urlopen(url_meteorological_img)).convert('RGB'))
    # Dedicated names so the x / y grid vectors above are not overwritten.
    height, width, channels = image.shape
    image_2d = image.reshape(height * width, channels)

    # cluster different regions of the image.
    kmeans_cluster = cluster.KMeans(n_clusters=10)
    kmeans_cluster.fit(image_2d)
    cluster_centers = kmeans_cluster.cluster_centers_
    cluster_labels = kmeans_cluster.labels_
    # Replace every pixel by the (integer-truncated) centre of its cluster.
    picture_data = cluster_centers[cluster_labels].reshape(height, width, channels).astype(int)
    cluster_colors = set(map(tuple, picture_data.reshape(-1, 3)))
    print(cluster_colors)
    array = picture_data

# Legend table, one row per class: (class code, [R, G, B] in 0-255, label).
_LEGEND_ROWS = [
    (127, [100, 100, 100], '无数据区'),
    (126, [0, 0, 0], '外太空'),
    (0, [200, 255, 255], '晴空'),
    (2, [0, 0, 244], '水云'),
    (3, [32, 165, 225], '过冷水云'),
    (4, [33, 255, 170], '混合云'),
    (5, [255, 0, 0], '厚冰云'),
    (6, [180, 20, 255], '卷云'),
    (7, [105, 255, 0], '多层云'),
    (9, [224, 180, 0], '不确定'),
]
df_legend = pd.DataFrame(_LEGEND_ROWS, columns=['key', 'color', 'name'])

def _swatch(rgb):
    """Return a unit-square legend patch filled with *rgb* (0-255 channels)."""
    return Rectangle((0, 0), 1, 1, color=[channel / 255 for channel in rgb])


# Proxy patches: one set for the predefined table, and one for the clustered
# image colours when an online image was loaded.
handles_1 = [_swatch(rgb) for rgb in df_legend['color']]
handles_2 = None
if USING_ONLINE_IMAGE_FLAG:
    handles_2 = [_swatch(rgb) for rgb in cluster_colors]
labels = df_legend['name']

plt.figure(figsize=(3, 1))
plt.subplots_adjust(hspace=0)
plt.rcParams.update({'legend.fontsize': 26})
# RGBA with alpha 0: ticks on both axes are coloured fully transparent.
plt.rc(('xtick', 'ytick'), color=(1, 1, 1, 0))

# Row 1: the image itself.
plt.subplot(3, 1, 1)
plt.imshow(array, aspect='auto')

# Row 2: legend for the predefined colour table.
plt.subplot(3, 1, 2)
plt.legend(handles_1, labels, mode='expand', ncol=3)

# Row 3: the subplot is always created; its legend only for an online image.
plt.subplot(3, 1, 3)
if USING_ONLINE_IMAGE_FLAG:
    plt.legend(handles_2, cluster_colors, mode='expand', ncol=3)
plt.show()

https://i.stack.imgur.com/OOIw5.png enter image description here

演示enter image description here