通过matplotlib中的子图矩阵图绘制线条

时间:2017-03-28 19:47:51

标签: python-3.x matplotlib

我想在由不同数量的子图组成的图上绘制线条,看起来像3x3示例的红线。我怎么能在matplotlib中做到这一点?

在4D示例中,代码基本上是多维数据的2D投影数组(scatter matrix plot的右上半部分)。

tabular matrix

from matplotlib import pyplot as plt
import numpy as np
data = np.random.random_sample((10,4))
labels = ['p1','p2','p3','p4']
fig, axarr = plt.subplots(3,3, sharex='col', sharey='row')
# Iterate over rows of subplots array
for row in range(axarr.shape[0]):
    i = row # data index corresponds to row index
    # Iterate over columns of subplots array
    for col in range(axarr.shape[1]):
        j = col+1 # data index corresponds to column index +1
        # Do what's needed in lower-left half of array and leave
        if row>col:
            if col==0:
                axarr[row,col].set_ylabel(labels[i],labelpad=5)
            axarr[row,col].spines['left'].set_visible(False)
            axarr[row,col].spines['right'].set_visible(False)
            axarr[row,col].spines['bottom'].set_visible(False)
            axarr[row,col].spines['top'].set_visible(False)
            axarr[row,col].xaxis.set_ticks_position('none')
            axarr[row,col].yaxis.set_ticks_position('none')
            axarr[row,col].tick_params(labelleft=False)
            axarr[row,col].tick_params(labelbottom=False)
            continue
        # Proceed with upper-right half of array
        axarr[row,col].scatter(data[:,i],data[:,j], s=4)
        axarr[row,col].tick_params(labelleft=False)
        axarr[row,col].tick_params(labelbottom=False)
        if row==0:
            axarr[row,col].set_xlabel(labels[j],labelpad=5)
            axarr[row,col].xaxis.set_label_position('top')
        if col==0:
            axarr[row,col].set_ylabel(labels[i],labelpad=5)
            axarr[row,col].yaxis.set_label_position('left')

1 个答案:

答案 0 :(得分:1)

这是一个独立于实际图形尺寸的解决方案,并与图形一起缩放。

我们使用混合变换来指定图形坐标中线条的长度,同时指定左上方子图的轴坐标中的垂直或水平位置。因此,垂直线在y方向的图坐标中从0到1,而在第一个子图的轴坐标中它被绑定到x = 0。 然后我们还添加一个偏移变换,将其移动一半的线宽,使其紧靠轴刺。

enter image description here

from matplotlib import pyplot as plt
import matplotlib.lines
import matplotlib.transforms as transforms
import numpy as np

data = np.random.random_sample((10,10))
labels = "Some labels around all the subplots"
fig, axarr = plt.subplots(3,3, sharex='col', sharey='row')
for i, ax in enumerate(axarr.flatten()):
    ax.scatter(data[:,i], data[:,i+1])
    ax.xaxis.set_label_position('top')
for i in range(3):
    axarr[2-i,0].set_ylabel(labels.split()[i])
    axarr[0,i].set_xlabel(labels.split()[i+3])
    axarr[2-i,0].set_yticklabels([])

#### Create lines ####
lw=4  # linewidth in points
#vertical line
offset1 = transforms.ScaledTranslation(-lw/72./2., 0, fig.dpi_scale_trans)
trans1 = transforms.blended_transform_factory(
    axarr[0,0].transAxes +offset1, fig.transFigure)
l1 = matplotlib.lines.Line2D([0,0], [0, 1], transform=trans1,
            figure=fig, color="#dd0000", linewidth=4, zorder=0)
#horizontal line
offset2 = transforms.ScaledTranslation(0,lw/72./2., fig.dpi_scale_trans)
trans2 = transforms.blended_transform_factory(
     fig.transFigure, axarr[0,0].transAxes+offset2)
l2 = matplotlib.lines.Line2D([0, 1], [1,1], transform=trans2,
            figure=fig, color="#dd0000", linewidth=4, zorder=0)
#add lines to canvas
fig.lines.extend([l1, l2])

plt.show()

这是一个较小的版本,从中可以看出线条的位置适应图形尺寸。

enter image description here