使用matplotlib并排绘制图像

时间:2017-01-22 17:24:29

标签: python matplotlib

我想知道如何使用matplotlib并排绘制图像,例如:

enter image description here

我最接近的是:

enter image description here

这是使用此代码生成的:

f, axarr = plt.subplots(2,2)
axarr[0,0] = plt.imshow(image_datas[0])
axarr[0,1] = plt.imshow(image_datas[1])
axarr[1,0] = plt.imshow(image_datas[2])
axarr[1,1] = plt.imshow(image_datas[3])

但我似乎无法展示其他图片。我认为必须有一个更好的方法来做到这一点,因为我想象试图管理索引会很痛苦。我已经浏览了documentation,虽然我有一种感觉,我可能会看错了。有人能够给我一个例子或指出我正确的方向吗?

7 个答案:

答案 0 :(得分:37)

您遇到的问题是,您尝试分配 imshowmatplotlib.image.AxesImage对现有轴对象的返回。

将图像数据绘制到axarr中不同轴的正确方法是

f, axarr = plt.subplots(2,2)
axarr[0,0].imshow(image_datas[0])
axarr[0,1].imshow(image_datas[1])
axarr[1,0].imshow(image_datas[2])
axarr[1,1].imshow(image_datas[3])

所有子图的概念都是相同的,并且在大多数情况下,轴实例提供的方法与pyplot(plt)接口相同。 例如。如果ax是您的子情节轴之一,为了绘制法线图,您可以使用ax.plot(..)代替plt.plot()。实际上,这可以在the page you link to的来源中找到。

答案 1 :(得分:13)

您正在一个轴上绘制所有图像。您想要的是单独获取每个轴的手柄并在那里绘制图像。像这样:

fig = plt.figure()
ax1 = fig.add_subplot(2,2,1)
ax1.imshow(...)
ax2 = fig.add_subplot(2,2,2)
ax2.imshow(...)
ax3 = fig.add_subplot(2,2,3)
ax3.imshow(...)
ax4 = fig.add_subplot(2,2,4)
ax4.imshow(...)

有关详细信息,请查看此处:http://matplotlib.org/examples/pylab_examples/subplots_demo.html

对于复杂的布局,您应该考虑使用gridspec:http://matplotlib.org/users/gridspec.html

答案 2 :(得分:13)

如果图像位于数组中,并且要遍历每个元素并打印,则可以编写如下代码:

# set on both throw and catch
breakpoint set -E C++ -h true
# or on catch
b __cxa_begin_catch
# or on throw
b __cxa_throw

还要注意,我使用了子图而不是子图。他们俩都不一样

答案 3 :(得分:5)

我发现可以使用的一件事很有帮助:

_, axs = plt.subplots(n_row, n_col, figsize=(12, 12))
axs = axs.flatten()
for img, ax in zip(imgs, axs):
    ax.imshow(img)
    plt.show()

答案 4 :(得分:1)

我大约每周访问一次这个网址。对于那些想要一个可以轻松绘制图像网格的小功能的人,我们开始:

import matplotlib.pyplot as plt
import numpy as np

def plot_image_grid(images, ncols=None, cmap='gray'):
    '''Plot a grid of images'''
    if not ncols:
        factors = [i for i in range(1, len(images)+1) if len(images) % i == 0]
        ncols = factors[len(factors) // 2] if len(factors) else len(images) // 4 + 1
    nrows = int(len(images) / ncols) + int(len(images) % ncols)
    imgs = [images[i] if len(images) > i else None for i in range(nrows * ncols)]
    f, axes = plt.subplots(nrows, ncols, figsize=(3*ncols, 2*nrows))
    axes = axes.flatten()[:len(imgs)]
    for img, ax in zip(imgs, axes.flatten()): 
        if np.any(img):
            if len(img.shape) > 2 and img.shape[2] == 1:
                img = img.squeeze()
            ax.imshow(img, cmap=cmap)

# make 16 images with 60 height, 80 width, 3 color channels
images = np.random.rand(16, 60, 80, 3)

# plot them
plot_image_grid(images)

答案 5 :(得分:0)

根据matplotlib's suggestion for image grids

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

fig = plt.figure(figsize=(4., 4.))
grid = ImageGrid(fig, 111,  # similar to subplot(111)
                 nrows_ncols=(2, 2),  # creates 2x2 grid of axes
                 axes_pad=0.1,  # pad between axes in inch.
                 )

for ax, im in zip(grid, image_data):
    # Iterating over the grid returns the Axes.
    ax.imshow(im)

plt.show()

答案 6 :(得分:0)

以下是在网格中并排显示图像的完整代码。您可以使用不同的参数调用函数 show_image_list()

  1. 传入图像列表,其中每个图像都是一个 Numpy 数组。默认情况下,它将创建一个包含 2 列的网格。它还会推断每个图像是彩色还是灰度。
list_images = [img, gradx, grady, mag_binary, dir_binary]

show_image_list(list_images, figsize=(10, 10))

enter image description here

  1. 传入图像列表、每个图像的标题列表和其他参数。
show_image_list(list_images=[img, gradx, grady, mag_binary, dir_binary], 
                list_titles=['original', 'gradx', 'grady', 'mag_binary', 'dir_binary'],
                num_cols=3,
                figsize=(20, 10),
                grid=False,
                title_fontsize=20)

enter image description here

代码如下:

import matplotlib.pyplot as plt
import numpy as np

def img_is_color(img):

    if len(img.shape) == 3:
        # Check the color channels to see if they're all the same.
        c1, c2, c3 = img[:, : , 0], img[:, :, 1], img[:, :, 2]
        if (c1 == c2).all() and (c2 == c3).all():
            return True

    return False

def show_image_list(list_images, list_titles=None, list_cmaps=None, grid=True, num_cols=2, figsize=(20, 10), title_fontsize=30):
    '''
    Shows a grid of images, where each image is a Numpy array. The images can be either
    RGB or grayscale.

    Parameters:
    ----------
    images: list
        List of the images to be displayed.
    list_titles: list or None
        Optional list of titles to be shown for each image.
    list_cmaps: list or None
        Optional list of cmap values for each image. If None, then cmap will be
        automatically inferred.
    grid: boolean
        If True, show a grid over each image
    num_cols: int
        Number of columns to show.
    figsize: tuple of width, height
        Value to be passed to pyplot.figure()
    title_fontsize: int
        Value to be passed to set_title().
    '''

    assert isinstance(list_images, list)
    assert len(list_images) > 0
    assert isinstance(list_images[0], np.ndarray)

    if list_titles is not None:
        assert isinstance(list_titles, list)
        assert len(list_images) == len(list_titles), '%d imgs != %d titles' % (len(list_images), len(list_titles))

    if list_cmaps is not None:
        assert isinstance(list_cmaps, list)
        assert len(list_images) == len(list_cmaps), '%d imgs != %d cmaps' % (len(list_images), len(list_cmaps))

    num_images  = len(list_images)
    num_cols    = min(num_images, num_cols)
    num_rows    = int(num_images / num_cols) + (1 if num_images % num_cols != 0 else 0)

    # Create a grid of subplots.
    fig, axes = plt.subplots(num_rows, num_cols, figsize=figsize)
    
    # Create list of axes for easy iteration.
    list_axes = list(axes.flat)

    for i in range(num_images):

        img    = list_images[i]
        title  = list_titles[i] if list_titles is not None else 'Image %d' % (i)
        cmap   = list_cmaps[i] if list_cmaps is not None else (None if img_is_color(img) else 'gray')
        
        list_axes[i].imshow(img, cmap=cmap)
        list_axes[i].set_title(title, fontsize=title_fontsize) 
        list_axes[i].grid(grid)

    for i in range(num_images, len(list_axes)):
        list_axes[i].set_visible(False)

    fig.tight_layout()
    _ = plt.show()