Matplotlib stacked histogram using `scatter_matrix` on a pandas DataFrame

Time: 2017-10-06 23:12:33

Tags: python pandas matplotlib histogram

I currently have the following code:

import matplotlib.pyplot as plt
import pandas as pd
from pandas.plotting import scatter_matrix

df= pd.read_csv(file, sep=',')
colors = list('r' if i==1 else 'b' for i in df['class']) # class is either 1 or 0
plt.figure()
scatter_matrix(df, color=colors )
plt.show()

It produces the following output:

[image: scatter matrix with plain histograms on the diagonal]

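However, on the diagonal of this plot I would like to show stacked histograms instead of plain ones, as shown below, with red for class 1 and blue for class 0.

[image: example of a stacked histogram, red for class 1 and blue for class 0]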

How can I do this?

2 Answers:

Answer 0 (score: 2)

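Seaborn is probably well suited to scatter-matrix style plots. However, I don't know of an easy way to draw a stacked histogram on the diagonal of seaborn's PairGrid.
Since the question asks for matplotlib anyway, the following is a solution using pandas and matplotlib. Unfortunately, it requires doing a lot by hand. Here is an example (note that seaborn is imported only to get some data, because the question did not provide any).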

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# seaborn import just needed to get some data
import seaborn as sns
df = sns.load_dataset("iris")


n_hist = 10  # number of bin edges per histogram, i.e. n_hist - 1 bins
category = "species"
columns = ["sepal_length","sepal_width","petal_length","petal_width"]


fig, axes = plt.subplots(nrows=len(columns), ncols=len(columns), 
                         sharex="col")

for i,row in enumerate(columns):
    for j,col in enumerate(columns):
        ax= axes[i,j]
        if i == j:
            # diagonal: stacked histogram, one layer per category
            mi = df[col].values.min()
            ma = df[col].values.max()
            hist_bins = np.linspace(mi, ma, n_hist)
            def hist(x):
                h, e = np.histogram(x.dropna()[col], bins=hist_bins)
                return pd.Series(h, e[:-1])
            # rows: left bin edges, columns: categories
            b = df[[col,category]].groupby(category).apply(hist).T
            # running total across categories = top of each layer
            values = np.cumsum(b.values, axis=1)
            width = np.diff(hist_bins)[0]
            for k in range(len(b.columns)):
                # height is this category's own count; it sits on the previous total
                bottom = values[:,k-1] if k > 0 else None
                ax.bar(b.index, b.values[:,k], width=width, bottom=bottom,
                       align="edge")  # b.index holds left bin edges
        else:
            # offdiagonal
            for (n,cat) in df.groupby(category):
                ax.scatter(cat[col],cat[row], s = 5,label=n, )
        ax.set_xlabel(col)
        ax.set_ylabel(row)
        #ax.legend()
plt.tight_layout()
plt.show()
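
As an alternative sketch, not part of the original answer: matplotlib's `Axes.hist` accepts a list of arrays together with `stacked=True`, which takes care of the stacking on the diagonal without computing the bar bottoms by hand. The helper name `stacked_scatter_matrix`, its parameters, and the random demo data (a stand-in for the question's CSV with a 0/1 `class` column) are assumptions made for illustration.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def stacked_scatter_matrix(df, category, columns, colors, bins=10):
    """Scatter matrix with stacked per-category histograms on the diagonal.

    `colors` maps each category value to a matplotlib color.
    """
    n = len(columns)
    fig, axes = plt.subplots(nrows=n, ncols=n, sharex="col", squeeze=False)
    groups = list(df.groupby(category))
    for i, row in enumerate(columns):
        for j, col in enumerate(columns):
            ax = axes[i, j]
            if i == j:
                # one array per category; hist() stacks them on shared bins
                data = [g[col].dropna().to_numpy() for _, g in groups]
                ax.hist(data, bins=bins, stacked=True,
                        color=[colors[key] for key, _ in groups])
            else:
                for key, g in groups:
                    ax.scatter(g[col], g[row], s=5, color=colors[key], label=key)
            ax.set_xlabel(col)
            ax.set_ylabel(row)
    fig.tight_layout()
    return fig, axes


# hypothetical stand-in for the question's data: two features and a 0/1 class
rng = np.random.default_rng(0)
demo = pd.DataFrame({
    "a": rng.normal(size=200),
    "b": rng.normal(size=200),
    "class": rng.integers(0, 2, size=200),
})
fig, axes = stacked_scatter_matrix(demo, "class", ["a", "b"],
                                   colors={0: "b", 1: "r"})
```

Call `plt.show()` afterwards to display the figure. With the question's data, the call would presumably take the loaded `df`, `"class"` as the category, and the list of feature columns to plot.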


Answer 1 (score: 1)

Example code:

import seaborn as sns
sns.set(style="ticks")
df = sns.load_dataset("iris")
sns.pairplot(df, hue="species")
