用于矩形网格的pyplot.imshow

时间:2016-07-12 08:30:32

标签: python matplotlib rectangles imshow

我目前正在使用matplotlib.pyplot来显示某些2D数据:

from matplotlib import pyplot as plt
import numpy as np
# 3x3 matrix holding the 2D data to display
A = np.matrix([[1, 2, 1],
               [3, 0, 3],
               [1, 2, 0]])
# "nearest" interpolation: every matrix entry becomes one uniformly colored square
plt.imshow(A, interpolation="nearest")
plt.show()

现在我将数据从正方形移动到矩形,这意味着我有两个额外的数组,例如:

# Cell boundaries along each axis (4 edges per axis delimit the 3 cells of A)
grid_x = np.array([0.0, 1.0, 4.0, 5.0]) # points on the x-axis
grid_y = np.array([0.0, 2.5, 4.0, 5.0]) # points on the y-axis

现在我想要一个带矩形的网格:

  • 左上角:(grid_x[i], grid_y[j])
  • 右下角:(grid_x[i+1], grid_y[j+1])
  • 数据(颜色):A[i,j]

在新网格上绘制数据的简便方法是什么? imshow似乎不适用;我看了pcolormesh,但它把网格当作2D数组的用法令人困惑:常规网格要使用两个矩阵(如np.mgrid[0:5:0.5,0:5:0.5]),不规则网格则要构建类似的东西。

用于可视化矩形的简单方式是什么?

2 个答案:

答案 0 :(得分:1)

import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import matplotlib.cm as cm
from matplotlib.collections import PatchCollection
import numpy as np

# 4x3 data matrix: one value per grid cell
A = np.matrix("1 2 1;3 0 3;1 2 0;4 1 2")

# Cell edges: 4 x-edges delimit 3 columns, 5 y-edges delimit 4 rows
grid_x0 = np.array([0.0, 1.0, 4.0, 6.7])
grid_y0 = np.array([0.0, 2.5, 4.0, 7.8, 12.4])

# Anchor corner (x, y) of every cell, laid out as a (rows, cols) array so that
# flattening it walks the cells in the same order as np.ravel(A)
corner_x, corner_y = np.meshgrid(grid_x0[:-1], grid_y0[:-1])
# Width and height of every cell in the same layout
cell_w, cell_h = np.meshgrid(np.diff(grid_x0), np.diff(grid_y0))

fig = plt.figure()
ax = fig.add_subplot(111)

# One Rectangle patch per cell
ptchs = [
    Rectangle((x0, y0), w, h)
    for x0, y0, w, h in zip(corner_x.flat, corner_y.flat, cell_w.flat, cell_h.flat)
]

# The collection maps the flattened data values to colors through the colormap
p = PatchCollection(ptchs, cmap=cm.viridis, alpha=0.4)
p.set_array(np.ravel(A))
ax.add_collection(p)

# add_collection is followed by explicit limits that enclose the whole grid
plt.xlim([0, 8])
plt.ylim([0, 13])
plt.show()

enter image description here

以下是另一种方法,使用图像和R树,配合imshow和colorbar;您需要更改x-ticks和y-ticks(SO上有很多关于如何做到这一点的问答)。

from rtree import index
import matplotlib.pyplot as plt
import numpy as np

# Half-size of the small query box placed around every sample point
eps = 1e-3

# 4x3 data matrix: one value per grid cell
A = np.matrix("1 2 1;3 0 3;1 2 0;4 1 2")
grid_x0 = np.array([0.0, 1.0, 4.0, 6.7])
grid_y0 = np.array([0.0, 2.5, 4.0, 7.8, 12.4])

# Two opposite corners of every cell, each laid out as a (rows, cols) array
lo_x, lo_y = np.meshgrid(grid_x0[:-1], grid_y0[:-1])
hi_x, hi_y = np.meshgrid(grid_x0[1:], grid_y0[1:])

fig = plt.figure()

# Resolution of the raster image the grid is sampled onto
rows = 100
cols = 200
im = np.zeros((rows, cols), dtype=np.int8)

# Sample coordinates: columns span the x range, rows span the y range
grid_j = np.linspace(grid_x0[0], grid_x0[-1], cols)
grid_i = np.linspace(grid_y0[0], grid_y0[-1], rows)

# Spatial index of the cell bounding boxes; the id is the cell's flat index
idx = index.Index()
for m, bounds in enumerate(zip(lo_x.flat, lo_y.flat, hi_x.flat, hi_y.flat)):
    idx.insert(m, bounds)

# Flat view of the data so the id stored in the index can be used directly
cell_values = np.ravel(A)
for r, i0 in enumerate(grid_i):
    for c, j0 in enumerate(grid_j):
        # Take the first cell the index reports for the box around the sample
        ind = next(idx.intersection((j0 - eps, i0 - eps, j0 + eps, i0 + eps)))
        im[r, c] = cell_values[ind]

plt.imshow(im)
plt.colorbar()
plt.show()

enter image description here

答案 1 :(得分:1)

这是一个基于@ophir-carmi代码的可重用函数:

import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from matplotlib.collections import PatchCollection
import itertools
import numpy as np

def gridshow(grid_x, grid_y, data, **kwargs):
    """Draw one colored rectangle per cell of a rectilinear grid on the current axes.

    Parameters
    ----------
    grid_x : sequence of float
        The n cell-edge coordinates along x.
    grid_y : sequence of float
        The m cell-edge coordinates along y.
    data : array-like
        (n-1)*(m-1) values, one per cell. It is flattened, and the flat order
        must have the x index varying fastest (row-major over (y, x)).
    **kwargs
        ``vmin`` / ``vmax`` are removed and used as the color limits; everything
        else is forwarded to ``PatchCollection``.

    Returns
    -------
    The ``PatchCollection`` that was added to the current axes. It is also
    registered as the current image via ``plt.sci``.

    Raises
    ------
    AssertionError
        If the number of data values does not match the number of cells.
    """
    vmin = kwargs.pop("vmin", None)
    vmax = kwargs.pop("vmax", None)
    # Accept plain lists/tuples as well as arrays for the edge coordinates
    grid_x = np.asarray(grid_x)
    grid_y = np.asarray(grid_y)
    data = np.array(data).reshape(-1)
    # there should be data for (n-1)x(m-1) cells
    assert (grid_x.shape[0] - 1) * (grid_y.shape[0] - 1) == data.shape[0], "Wrong number of data points. grid_x=%s, grid_y=%s, data=%s" % (grid_x.shape, grid_y.shape, data.shape)
    ptchs = []
    # range instead of the Python-2-only xrange; y is the outer loop so the
    # patches are created in the same order as the flattened data
    for j, i in itertools.product(range(len(grid_y) - 1), range(len(grid_x) - 1)):
        xy = grid_x[i], grid_y[j]
        width = grid_x[i+1] - grid_x[i]
        height = grid_y[j+1] - grid_y[j]
        ptchs.append(Rectangle(xy=xy, width=width, height=height, rasterized=True, linewidth=0, linestyle="None"))
    p = PatchCollection(ptchs, linewidth=0, **kwargs)
    # data is already a fresh flat array, no further copy needed
    p.set_array(data)
    p.set_clim(vmin, vmax)
    ax = plt.gca()
    ax.set_aspect("equal")
    # Limits are set explicitly to the extent of the grid
    plt.xlim([grid_x[0], grid_x[-1]])
    plt.ylim([grid_y[0], grid_y[-1]])
    ret = ax.add_collection(p)
    # Make the collection the current mappable, as plt.sci does for colorbars
    plt.sci(ret)
    return ret

if __name__ == "__main__":
    # Roughly unit-spaced edges with random jitter, rounded to two decimals:
    # 21 edges along x and 19 along y
    grid_x = np.round(np.linspace(0, 20, 21) + np.random.randn(21)/5.0, 2)
    grid_y = np.round(np.linspace(0, 18, 19) + np.random.randn(19)/5.0, 2)

    # One random value per cell
    n_cells = (grid_x.shape[0] - 1) * (grid_y.shape[0] - 1)
    data = np.random.randn(n_cells)

    fig = plt.figure()
    ax = fig.add_subplot(111)
    gridshow(grid_x, grid_y, data, alpha=1.0)
    plt.savefig("test.png")

example image from the <code>gridshow</code> function

我不完全确定它在大型网格上的效果如何,也不确定**kwargs是否应该传给PatchCollection。某些矩形之间似乎有1px的空白,可能是舍入不当造成的。也许xy、width、height需要一致地floor/ceil到下一个完整像素。

使用rtree和imshow的另一种解决方案:

import matplotlib.pyplot as plt
import numpy as np
from rtree import index

def gridshow(grid_x, grid_y, data, rows=200, cols=200, eps=1e-3, **kwargs):
    """Rasterize a rectilinear grid onto a rows x cols image and show it with imshow.

    Each pixel is a sample point inside the grid's extent. The pixel takes the
    value of the first cell the R-tree reports for a small box of half-size
    ``eps`` around that point.

    Parameters
    ----------
    grid_x, grid_y : array-like
        Cell-edge coordinates along x and y.
    data : array-like
        One value per cell, either flat or shaped (len(grid_y)-1, len(grid_x)-1);
        the cell's row-major flat index is resolved against ``data.shape``.
    rows, cols : int
        Size of the raster image.
    eps : float
        Half-size of the query box around each sample point.
    **kwargs
        Forwarded to ``plt.imshow``. ``interpolation`` defaults to "nearest".

    Returns
    -------
    The object returned by ``plt.imshow``. No ``extent`` is set here, so the
    axes show pixel indices unless the caller passes ``extent`` in ``kwargs``.

    Raises
    ------
    StopIteration
        If the index reports no cell for a sample point.
    """
    # Allow lists / np.matrix; indexing below then always yields a scalar
    data = np.asarray(data)

    # Two opposite corners of every cell, flattened in row-major order
    grid_x1, grid_y1 = np.meshgrid(grid_x, grid_y)
    grid_x2 = grid_x1[:-1, :-1].flat
    grid_y2 = grid_y1[:-1, :-1].flat
    grid_x3 = grid_x1[1:, 1:].flat
    grid_y3 = grid_y1[1:, 1:].flat

    # Sample coordinates: columns span the x range, rows span the y range
    grid_j = np.linspace(grid_x[0], grid_x[-1], cols)
    grid_i = np.linspace(grid_y[0], grid_y[-1], rows)
    j, i = np.meshgrid(grid_j, grid_i)
    i = i.flat
    j = j.flat

    # Every element is assigned in the loop below, so empty() is sufficient
    im = np.empty((rows, cols), dtype=np.float64)

    # Spatial index of the cell bounding boxes; the id is the cell's flat index
    idx = index.Index()
    for m, (x0, y0, x1, y1) in enumerate(zip(grid_x2, grid_y2, grid_x3, grid_y3)):
        idx.insert(m, (x0, y0, x1, y1))
    for k, (i0, j0) in enumerate(zip(i, j)):
        ind = next(idx.intersection((j0-eps, i0-eps, j0+eps, i0+eps)))
        im[np.unravel_index(k, im.shape)] = data[np.unravel_index(ind, data.shape)]

    # Previously **kwargs was accepted but silently dropped; pass it through
    # while keeping "nearest" as the default interpolation
    kwargs.setdefault("interpolation", "nearest")
    return plt.imshow(im, **kwargs)


if __name__ == "__main__":
    # Roughly unit-spaced edges with random jitter, rounded to two decimals:
    # 201 edges along x and 109 along y
    grid_x = np.round(np.linspace(0, 200, 201) + np.random.randn(201)/5.0, 2)
    grid_y = np.round(np.linspace(0, 108, 109) + np.random.randn(109)/5.0, 2)

    # One random value per cell
    n_cells = (grid_x.shape[0] - 1) * (grid_y.shape[0] - 1)
    data = np.random.randn(n_cells)

    fig = plt.figure()
    ax = fig.add_subplot(111)
    gridshow(grid_x, grid_y, data, alpha=1.0)
    plt.savefig("test.png")

image with imshow and rtree