Is there a function to make scatterplot matrices in matplotlib?

Time: 2011-10-29 19:36:42

Tags: python matplotlib scatter-plot

Example of a scatterplot matrix

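Is there such a function in matplotlib.pyplot?

5 answers:

Answer 0 (score: 95)

For those who do not want to define their own function, Python has a great data analysis library called Pandas, which provides the scatter_matrix() function: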

import numpy as np
import pandas as pd
from pandas.plotting import scatter_matrix

df = pd.DataFrame(np.random.randn(1000, 4), columns=['a', 'b', 'c', 'd'])
scatter_matrix(df, alpha=0.2, figsize=(6, 6), diagonal='kde')
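scatter_matrix() returns a NumPy array of matplotlib Axes, so the grid can still be adjusted after it is drawn. The following is a minimal sketch, assuming pandas and matplotlib are installed. The Agg backend, the seed, and the label rotation are illustrative choices, and diagonal='hist' is used so that SciPy is not needed:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend, so no display is required
import numpy as np
import pandas as pd
from pandas.plotting import scatter_matrix

rng = np.random.default_rng(0)
df = pd.DataFrame(rng.standard_normal((200, 4)), columns=["a", "b", "c", "d"])

# One Axes per pair of columns; histograms on the diagonal.
axes = scatter_matrix(df, alpha=0.2, figsize=(6, 6), diagonal="hist")

# Make the row labels in the first column horizontal.
for ax in axes[:, 0]:
    ax.yaxis.label.set_rotation(0)
    ax.yaxis.label.set_ha("right")

# The figure is reachable from any of the returned Axes.
fig = axes[0, 0].get_figure()
fig.suptitle("scatter_matrix with adjusted labels")
```

Because the return value is a plain array of Axes, any per-subplot matplotlib call can be applied the same way.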

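Answer 1 (score: 20)

Generally speaking, matplotlib does not include plotting functions that operate on more than one axes object (subplots, in this case). The expectation is that you will write a simple function to string things together however you like.

I'm not quite sure what your data looks like, but it is fairly simple to build a function for this from scratch. If you are always going to work with structured or rec arrays, you can simplify it a bit (i.e. each data series always has a name associated with it, so you do not have to specify the names separately).

As an example: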

import itertools
import numpy as np
import matplotlib.pyplot as plt

def main():
    np.random.seed(1977)
    numvars, numdata = 4, 10
    data = 10 * np.random.random((numvars, numdata))
    fig = scatterplot_matrix(data, ['mpg', 'disp', 'drat', 'wt'],
            linestyle='none', marker='o', color='black', mfc='none')
    fig.suptitle('Simple Scatterplot Matrix')
    plt.show()

def scatterplot_matrix(data, names, **kwargs):
    """Plots a scatterplot matrix of subplots.  Each row of "data" is plotted
    against other rows, resulting in a nrows by nrows grid of subplots with the
    diagonal subplots labeled with "names".  Additional keyword arguments are
    passed on to matplotlib's "plot" command. Returns the matplotlib figure
    object containing the subplot grid."""
    numvars, numdata = data.shape
    fig, axes = plt.subplots(nrows=numvars, ncols=numvars, figsize=(8,8))
    fig.subplots_adjust(hspace=0.05, wspace=0.05)

    for ax in axes.flat:
        # Hide all ticks and labels
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)

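        # Set up ticks only on one side for the "edge" subplots...
        # (Axes.is_first_col() etc. were removed in matplotlib 3.6; the
        # SubplotSpec methods below are the replacement.)
        spec = ax.get_subplotspec()
        if spec.is_first_col():
            ax.yaxis.set_ticks_position('left')
        if spec.is_last_col():
            ax.yaxis.set_ticks_position('right')
        if spec.is_first_row():
            ax.xaxis.set_ticks_position('top')
        if spec.is_last_row():
            ax.xaxis.set_ticks_position('bottom')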

    # Plot the data.
    for i, j in zip(*np.triu_indices_from(axes, k=1)):
        for x, y in [(i,j), (j,i)]:
            axes[x,y].plot(data[x], data[y], **kwargs)

    # Label the diagonal subplots...
    for i, label in enumerate(names):
        axes[i,i].annotate(label, (0.5, 0.5), xycoords='axes fraction',
                ha='center', va='center')

    # Turn on the proper x or y axes ticks.
    for i, j in zip(range(numvars), itertools.cycle((-1, 0))):
        axes[j,i].xaxis.set_visible(True)
        axes[i,j].yaxis.set_visible(True)

    return fig

main()
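The simplification mentioned above for structured or rec arrays could look something like the sketch below. The helper name structured_to_rows is hypothetical (it is not part of NumPy or matplotlib). It only pulls the field names and values out of the array, so that they can be passed on as the "data" and "names" arguments:

```python
import numpy as np

def structured_to_rows(rec):
    """Hypothetical helper: split a structured array into (data, names).

    "data" has one row per field, which is the layout scatterplot_matrix
    expects, and "names" is the list of field names.
    """
    names = list(rec.dtype.names)
    data = np.vstack([rec[name].astype(float) for name in names])
    return data, names

# Usage: a small structured array with two named fields.
rec = np.array([(21.0, 2.62), (22.8, 2.32), (18.7, 3.44)],
               dtype=[('mpg', float), ('wt', float)])
data, names = structured_to_rows(rec)
# data.shape is (2, 3): one row per field, one column per record.
```

With that in place, a call such as scatterplot_matrix(*structured_to_rows(rec)) would not need the names typed out by hand.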

Answer 2 (score: 12)

You can also use Seaborn's pairplot function:

import seaborn as sns
sns.set()
df = sns.load_dataset("iris")
sns.pairplot(df, hue="species")

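Answer 3 (score: 10)

Thanks for sharing your code! You figured out all the hard stuff for us. As I was working with it, I noticed a few little things that did not look quite right.

  1. [FIX #1] The axis ticks did not line up as I expected. In your example above, you should be able to draw a vertical or horizontal line through any point across all plots and have it pass through the corresponding point in the other plots, but as it stands this does not happen.

  2. [FIX #2] If you are plotting an odd number of variables, the bottom-right axes do not pick up the correct x ticks or y ticks. They are just left at the default 0..1 ticks.

  3. Not a fix, but I made the explicit names argument optional, so that a default label xi is put on the diagonal for variable i.

  4. Below you will find an updated version of your code that addresses these two points and otherwise preserves the beauty of your code.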

    import itertools
    import numpy as np
    import matplotlib.pyplot as plt
    
    def scatterplot_matrix(data, names=None, **kwargs):
        """
        Plots a scatterplot matrix of subplots.  Each row of "data" is plotted
        against other rows, resulting in a nrows by nrows grid of subplots with the
        diagonal subplots labeled with "names".  Additional keyword arguments are
        passed on to matplotlib's "plot" command. Returns the matplotlib figure
        object containing the subplot grid.
        """
        numvars, numdata = data.shape
        fig, axes = plt.subplots(nrows=numvars, ncols=numvars, figsize=(8,8))
        fig.subplots_adjust(hspace=0.0, wspace=0.0)
    
        for ax in axes.flat:
            # Hide all ticks and labels
            ax.xaxis.set_visible(False)
            ax.yaxis.set_visible(False)
    
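            # Set up ticks only on one side for the "edge" subplots...
            # (Axes.is_first_col() etc. were removed in matplotlib 3.6; the
            # SubplotSpec methods below are the replacement.)
            spec = ax.get_subplotspec()
            if spec.is_first_col():
                ax.yaxis.set_ticks_position('left')
            if spec.is_last_col():
                ax.yaxis.set_ticks_position('right')
            if spec.is_first_row():
                ax.xaxis.set_ticks_position('top')
            if spec.is_last_row():
                ax.xaxis.set_ticks_position('bottom')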
    
        # Plot the data.
        for i, j in zip(*np.triu_indices_from(axes, k=1)):
            for x, y in [(i,j), (j,i)]:
                # FIX #1: this needed to be changed from ...(data[x], data[y],...)
                axes[x,y].plot(data[y], data[x], **kwargs)
    
        # Label the diagonal subplots...
        if not names:
            names = ['x'+str(i) for i in range(numvars)]
    
        for i, label in enumerate(names):
            axes[i,i].annotate(label, (0.5, 0.5), xycoords='axes fraction',
                    ha='center', va='center')
    
        # Turn on the proper x or y axes ticks.
        for i, j in zip(range(numvars), itertools.cycle((-1, 0))):
            axes[j,i].xaxis.set_visible(True)
            axes[i,j].yaxis.set_visible(True)
    
        # FIX #2: if numvars is odd, the bottom right corner plot doesn't have the
        # correct axes limits, so we pull them from other axes
        if numvars%2:
            xlimits = axes[0,-1].get_xlim()
            ylimits = axes[-1,0].get_ylim()
            axes[-1,-1].set_xlim(xlimits)
            axes[-1,-1].set_ylim(ylimits)
    
        return fig
    
    if __name__=='__main__':
        np.random.seed(1977)
        numvars, numdata = 4, 10
        data = 10 * np.random.random((numvars, numdata))
        fig = scatterplot_matrix(data, ['mpg', 'disp', 'drat', 'wt'],
                linestyle='none', marker='o', color='black', mfc='none')
        fig.suptitle('Simple Scatterplot Matrix')
        plt.show()
    

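Thanks again for sharing this with us. I have used it many times! Oh, and I rearranged the main() part of the code so that it can serve as a proper example and is not called when the file is imported into another piece of code.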
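The alignment rule behind FIX #1 is that the subplot in row r and column c plots variable c on x and variable r on y. A different way to get the same alignment, and consistent limits in the corner subplot (FIX #2), is to let matplotlib share the axes. This is not the approach used in the code above, only a sketch of an alternative. The function name shared_axes_matrix and the Agg backend are assumptions made for illustration:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend, so no display is required
import matplotlib.pyplot as plt
import numpy as np

def shared_axes_matrix(data, names=None, **kwargs):
    """Sketch: scatterplot matrix whose columns share x and rows share y."""
    numvars = data.shape[0]
    if not names:
        names = ['x' + str(i) for i in range(numvars)]

    # sharex='col' / sharey='row' keeps limits and ticks aligned across the
    # grid, including the empty diagonal subplots.
    fig, axes = plt.subplots(nrows=numvars, ncols=numvars, figsize=(8, 8),
                             sharex='col', sharey='row', squeeze=False)
    fig.subplots_adjust(hspace=0.0, wspace=0.0)

    for row in range(numvars):
        for col in range(numvars):
            if row == col:
                # Label the diagonal instead of plotting a variable
                # against itself.
                axes[row, col].annotate(names[row], (0.5, 0.5),
                                        xycoords='axes fraction',
                                        ha='center', va='center')
            else:
                # Column variable on x, row variable on y.
                axes[row, col].plot(data[col], data[row], linestyle='none',
                                    marker='o', **kwargs)
    return fig, axes

# Usage with three variables (an odd count, as in FIX #2).
demo = np.arange(12.0).reshape(3, 4)
fig, axes = shared_axes_matrix(demo, ['a', 'b', 'c'])
```

With shared axes, matplotlib also hides the inner tick labels on its own, so the manual tick bookkeeping is not needed in this variant.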

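Answer 4 (score: 4)

While reading the question I expected to see an answer that used rpy. I think it is a nice option that takes advantage of two beautiful languages. So here it is: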

import rpy
import numpy as np

def main():
    np.random.seed(1977)
    numvars, numdata = 4, 10
    data = 10 * np.random.random((numvars, numdata))
    mpg = data[0,:]
    disp = data[1,:]
    drat = data[2,:]
    wt = data[3,:]
    rpy.set_default_mode(rpy.NO_CONVERSION)

    R_data = rpy.r.data_frame(mpg=mpg,disp=disp,drat=drat,wt=wt)

    # Figure saved as eps
    rpy.r.postscript('pairsPlot.eps')
    rpy.r.pairs(R_data,
       main="Simple Scatterplot Matrix Via RPy")
    rpy.r.dev_off()

    # Figure saved as png
    rpy.r.png('pairsPlot.png')
    rpy.r.pairs(R_data,
       main="Simple Scatterplot Matrix Via RPy")
    rpy.r.dev_off()

    rpy.set_default_mode(rpy.BASIC_CONVERSION)


if __name__ == '__main__': main()

I can't post an image to show the result :( Sorry!