我很喜欢 R 包 PerformanceAnalytics 中 chart.Correlation 函数绘制的这种相关矩阵图:
如何在 Python 中绘制这种图?我见过的 Python 相关矩阵图大多是热图,例如这个 seaborn 示例。
答案 0(得分:5)
下面的 cor_matrix 函数实现了这种图,并在下三角额外绘制了二元核密度估计图。感谢 @karl-anka 的评论为我提供了思路。
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
# Plain white seaborn theme for the figure below, then load the example data.
sns.set(style="white")
iris = sns.load_dataset("iris")
def corrfunc(x, y, **kws):
    """Write Pearson's r and significance stars at the top-left of the current axes.

    Intended for ``PairGrid.map_lower``. Any extra keyword arguments are
    accepted and ignored.

    Stars: ``*`` for p <= 0.05, ``**`` for p <= 0.01, ``***`` for p <= 0.001,
    nothing otherwise (including a NaN p-value).
    """
    r, p = stats.pearsonr(x, y)
    # Test the strictest threshold first; the first match wins.
    if p <= 0.001:
        stars = '***'
    elif p <= 0.01:
        stars = '**'
    elif p <= 0.05:
        stars = '*'
    else:
        stars = ''
    target = plt.gca()
    label = 'r = {:.2f} '.format(r) + stars
    target.annotate(label, xy=(0.05, 0.9), xycoords=target.transAxes)
def annotate_colname(x, **kws):
    """Write ``x.name`` in bold at the top-left of the current axes.

    Intended for ``PairGrid.map_diag``. ``x`` must expose a ``.name``
    attribute (e.g. a pandas Series); a plain NumPy array raises
    AttributeError. Extra keyword arguments are ignored.
    """
    target = plt.gca()
    target.annotate(
        x.name,
        xy=(0.05, 0.9),
        xycoords=target.transAxes,
        fontweight='bold',
    )
def cor_matrix(df):
    """Draw a PerformanceAnalytics-style correlation matrix for ``df``.

    Upper triangle: scatter plot with a linear-regression fit and CI band.
    Diagonal: density-scaled histogram with a KDE curve, labelled with the
    column name (``annotate_colname``).
    Lower triangle: bivariate KDE annotated with Pearson's r and significance
    stars (``corrfunc``).

    Returns the ``sns.PairGrid`` so callers can adjust or save the figure.
    """
    g = sns.PairGrid(df, palette=['red'])
    # Use normal regplot as `lowess=True` doesn't provide CIs.
    g.map_upper(sns.regplot, scatter_kws={'s': 10})
    # sns.distplot is deprecated; histplot with kde=True and stat='density'
    # draws the density-normalised histogram + KDE overlay distplot produced.
    g.map_diag(sns.histplot, kde=True, stat='density')
    g.map_diag(annotate_colname)
    g.map_lower(sns.kdeplot, cmap='Blues_d')
    g.map_lower(corrfunc)
    # Remove axis labels, as they're in the diagonals.
    for ax in g.axes.flat:
        ax.set_ylabel('')
        ax.set_xlabel('')
    return g
# Build and draw the correlation matrix for the iris data.
cor_matrix(df=iris)
答案 1(得分:5)
另一种解决方案如下:
import matplotlib.pyplot as plt
import seaborn as sns
def corrdot(*args, **kwargs):
    """Draw Pearson's r as a coloured dot with its value in the centre of the axes.

    Intended for ``PairGrid.map_upper``. ``args[0]`` and ``args[1]`` are the
    x and y data, expected to be pandas Series (``.corr(other, method)`` is
    called on the first). Dot area and font size grow with |r|; dot colour
    maps r onto the ``coolwarm`` colormap over [-1, 1]. The axes frame and
    ticks are hidden. Extra keyword arguments are ignored.
    """
    x_data, y_data = args[0], args[1]
    r = x_data.corr(y_data, 'pearson')
    strength = abs(r)
    # Drop the leading zero ("0.87" -> ".87", "-0.42" -> "-.42").
    label = f"{r:2.2f}".replace("0.", ".")
    ax = plt.gca()
    ax.set_axis_off()
    ax.scatter(
        [.5], [.5],
        s=strength * 10000,
        c=[r],
        alpha=0.6,
        cmap="coolwarm",
        vmin=-1,
        vmax=1,
        transform=ax.transAxes,
    )
    ax.annotate(
        label,
        [0.5, 0.5],
        xycoords="axes fraction",
        ha='center',
        va='center',
        fontsize=strength * 40 + 5,
    )
# Answer 1 figure: lowess fits below the diagonal, histogram + KDE on it,
# correlation dots above it.
sns.set(style='white', font_scale=1.6)
iris = sns.load_dataset('iris')
g = sns.PairGrid(iris, aspect=1.4, diag_sharey=False)
g.map_lower(
    sns.regplot,
    lowess=True,
    ci=False,
    line_kws=dict(color='black'),
)
# NOTE: sns.distplot is deprecated in newer seaborn releases (sns.histplot replaces it).
g.map_diag(sns.distplot, kde_kws=dict(color='black'))
g.map_upper(corrdot)
现在,如果您确实想模仿那张 R 图的外观,可以把上面的代码与您提供的解决方案中的部分内容结合起来:
import matplotlib.pyplot as plt
from scipy import stats
import seaborn as sns
import numpy as np
def corrdot(*args, **kwargs):
    """Write Pearson's r, with two decimals, in the centre of the current axes.

    Intended for ``PairGrid.map_upper``. ``args[0]`` and ``args[1]`` are the
    x and y data, expected to be pandas Series (``.corr(other, method)`` is
    called on the first). The font size grows with |r| so strong correlations
    stand out, imitating the R chart.Correlation figure. Extra keyword
    arguments are ignored.
    """
    corr_r = args[0].corr(args[1], 'pearson')
    # Format instead of round(): round() returns a float, so a value such as
    # 0.10 would be displayed as "0.1"; the format spec always shows two
    # decimals.
    corr_text = f"{corr_r:.2f}"
    ax = plt.gca()
    font_size = abs(corr_r) * 80 + 5
    ax.annotate(corr_text, [0.5, 0.5], xycoords="axes fraction",
                ha='center', va='center', fontsize=font_size)
def corrfunc(x, y, **kws):
    """Overlay red significance stars for the Pearson correlation of x and y.

    Only the stars are drawn here; in this answer the r value itself is drawn
    separately by ``corrdot``. ``*`` for p <= 0.05, ``**`` for p <= 0.01,
    ``***`` for p <= 0.001, an empty string otherwise. Extra keyword
    arguments are ignored.
    """
    _, p_value = stats.pearsonr(x, y)
    # Strictest threshold first; the first one satisfied decides the marker.
    thresholds = ((0.001, '***'), (0.01, '**'), (0.05, '*'))
    stars = next((mark for limit, mark in thresholds if p_value <= limit), '')
    target = plt.gca()
    target.annotate(stars, xy=(0.65, 0.6), xycoords=target.transAxes,
                    color='red', fontsize=70)
# Answer 1, second figure: a closer imitation of R's chart.Correlation.
sns.set(style='white', font_scale=1.6)
iris = sns.load_dataset('iris')
g = sns.PairGrid(iris, aspect=1.5, diag_sharey=False, despine=False)

# Lower triangle: black points with a thin red lowess curve.
g.map_lower(
    sns.regplot,
    lowess=True,
    ci=False,
    line_kws=dict(color='red', lw=1),
    scatter_kws=dict(color='black', s=20),
)

# Diagonal: grey histogram, red KDE and a rug of the raw values.
# NOTE: sns.distplot is deprecated in newer seaborn releases.
g.map_diag(
    sns.distplot,
    color='black',
    kde_kws=dict(color='red', cut=0.7, lw=1),
    hist_kws=dict(histtype='bar', lw=2, edgecolor='k', facecolor='grey'),
)
g.map_diag(sns.rugplot, color='black')

# Upper triangle: r value sized by strength, plus significance stars.
g.map_upper(corrdot)
g.map_upper(corrfunc)

# Pack the panels edge to edge like the R figure.
g.fig.subplots_adjust(wspace=0, hspace=0)

# Axis labels are redundant once the diagonal panels carry titles.
for ax in g.axes.flatten():
    ax.set_xlabel('')
    ax.set_ylabel('')

# Title each diagonal panel with its column name. zip pairs the diagonal
# axes with the leading columns of iris; assumes any non-plotted
# (non-numeric) columns come last -- TODO confirm for other data.
for ax, col in zip(np.diag(g.axes), iris.columns):
    ax.set_title(col, y=0.82, fontsize=26)
这与在 R 中对 iris 数据集调用 chart.Correlation() 得到的图非常接近:
# Reference figure in R: PerformanceAnalytics' correlation chart of iris
# with column 5 dropped, histograms on the diagonal.
library("PerformanceAnalytics")
iris_numeric <- data.matrix(iris[, 1:4])
chart.Correlation(iris_numeric, histogram = TRUE, pch = 20)
答案 2(得分:0)
为了解决在 ax.annotate(x.name, xy=(0.05, 0.9), xycoords=ax.transAxes, fontweight='bold') 处报出的 'numpy.ndarray' object has no attribute 'name' 错误,同时保持函数的通用性,可以在 cor_matrix 函数内部创建一个迭代器,并把 annotate_colname 函数移到 cor_matrix 函数内部,如下所示。
def corrfunc(x, y, **kws):
    """Write Pearson's r and significance stars at the top-left of the current axes.

    Intended for ``PairGrid.map_lower``. Any extra keyword arguments are
    accepted and ignored.

    Stars: ``*`` for p <= 0.05, ``**`` for p <= 0.01, ``***`` for p <= 0.001,
    nothing otherwise.
    """
    r, p = stats.pearsonr(x, y)
    p_stars = ''
    if p <= 0.05:
        p_stars = '*'
    if p <= 0.01:
        p_stars = '**'
    if p <= 0.001:
        p_stars = '***'
    ax = plt.gca()
    # The keyword is `xycoords` (the misspelled `ycoords` is not a valid
    # annotate/Text property, so matplotlib raised an error on every call).
    ax.annotate('r = {:.2f} '.format(r) + p_stars, xy=(0.05, 0.9),
                xycoords=ax.transAxes)
def cor_matrix(df, save=False):
    """Draw a PerformanceAnalytics-style correlation matrix for ``df``.

    Upper triangle: scatter plot with a red linear-regression fit.
    Diagonal: histogram with a KDE curve, labelled in bold with the variable name.
    Lower triangle: bivariate KDE annotated with Pearson's r and significance
    stars (via ``corrfunc``).

    Parameters
    ----------
    df : pandas.DataFrame
        Data to plot; only the columns the PairGrid selects are drawn.
    save : bool, default False
        If true, also write the figure to ``corr_mat.png`` in the current
        working directory.

    Returns
    -------
    seaborn.PairGrid
        The grid, so callers can adjust or save the figure further.
    """
    g = sns.PairGrid(df, palette=['red'])

    # ======= NEW ITERATION FUNCTION ====
    # Take labels from the variables the grid actually plots, not from every
    # column of ``df``: PairGrid skips non-numeric columns, so iterating
    # ``df`` would shift every label after such a column. Relies on map_diag
    # calling annotate_colname exactly once per diagonal, in order (no hue).
    label_iter = iter(g.x_vars).__next__
    # ====================================

    def annotate_colname(x, **kws):
        """Write the next diagonal variable name in bold on the current axes.

        ``x`` is ignored because it may be a NumPy array without ``.name``.
        """
        ax = plt.gca()
        ax.annotate(label_iter(), xy=(0.05, 0.9), xycoords=ax.transAxes,
                    fontweight='bold')

    # Use normal regplot as `lowess=True` doesn't provide CIs.
    g.map_upper(sns.regplot, scatter_kws={'s': 10}, line_kws={"color": "red"})
    g.map_diag(sns.histplot, kde=True)  # histplot replaces the deprecated distplot
    g.map_diag(annotate_colname)
    g.map_lower(sns.kdeplot, cmap='Blues_d')
    g.map_lower(corrfunc)
    # Remove axis labels, as they're in the diagonals.
    for ax in g.axes.flat:
        ax.set_ylabel('')
        ax.set_xlabel('')
    if save:
        # Save this grid's own figure rather than whatever figure is current.
        g.fig.savefig('corr_mat.png')
    return g