在Matplotlib中绘制两个图之间的线

时间:2013-07-09 08:22:15

标签: python matplotlib

我正在使用Matplotlib绘制两个子图,主要是:

subplot(211); imshow(a); scatter(..., ...)
subplot(212); imshow(b); scatter(..., ...)

我可以在这两个子图之间画线吗?我该怎么做?

3 个答案:

答案 0 :(得分:25)

您可以使用fig.line。它为你的身材增加了任何一条线。图线比轴线更高,因此您不需要任何轴来绘制它。

此示例在两个轴上标记相同的点。有必要小心坐标系,但转换会为你完成所有艰苦的工作。

import matplotlib.pyplot as plt
import matplotlib
import numpy as np

fig = plt.figure(figsize=(10,5))
ax1 = fig.add_subplot(121)
ax2 = fig.add_subplot(122)

x,y = np.random.rand(100),np.random.rand(100)

ax1.plot(x,y,'ko')
ax2.plot(x,y,'ko')

i = 10

transFigure = fig.transFigure.inverted()

coord1 = transFigure.transform(ax1.transData.transform([x[i],y[i]]))
coord2 = transFigure.transform(ax2.transData.transform([x[i],y[i]]))


line = matplotlib.lines.Line2D((coord1[0],coord2[0]),(coord1[1],coord2[1]),
                               transform=fig.transFigure)
fig.lines = line,

ax1.plot(x[i],y[i],'ro',markersize=20)
ax2.plot(x[i],y[i],'ro',markersize=20)


plt.show()

enter image description here

答案 1 :(得分:19)

其他答案的解决方案在许多情况下都不是最理想的(因为它们只有在计算点数后没有对绘图进行任何更改时才会起作用。)

更好的解决方案是使用专门设计的ConnectionPatch

import matplotlib.pyplot as plt
from matplotlib.patches import ConnectionPatch
import numpy as np

fig = plt.figure(figsize=(10,5))
ax1 = fig.add_subplot(121)
ax2 = fig.add_subplot(122)

x,y = np.random.rand(100),np.random.rand(100)

ax1.plot(x,y,'ko')
ax2.plot(x,y,'ko')

i = 10
xy = (x[i],y[i])
con = ConnectionPatch(xyA=xy, xyB=xy, coordsA="data", coordsB="data",
                      axesA=ax2, axesB=ax1, color="red")
ax2.add_artist(con)

ax1.plot(x[i],y[i],'ro',markersize=10)
ax2.plot(x[i],y[i],'ro',markersize=10)


plt.show()

enter image description here

答案 2 :(得分:1)

我不确定这是否正是您正在寻找的,但是一个简单的方法来绘制子图。

import matplotlib.pyplot as plt
import numpy as np

ax1=plt.figure(1).add_subplot(211)
ax2=plt.figure(1).add_subplot(212)

x_data=np.linspace(0,10,20)
ax1.plot(x_data, x_data**2,'o')
ax2.plot(x_data, x_data**3, 'o')

ax3 = plt.figure(1).add_subplot(111)
ax3.plot([5,5],[0,1],'--')
ax3.set_xlim([0,10])
ax3.axis("off")
plt.show()