使用seaborn

时间:2018-02-20 20:16:44

标签: python pandas matplotlib seaborn

我有一个名为recast的数据框,其结构如下:

data = [[ 374000,np.nan,749500,np.nan],
            [np.nan,np.nan,298000,np.nan],
            [540065.326633,np.nan,904750,np.nan],
            [np.nan,np.nan,514000,np.nan],
            [411000,np.nan,np.nan,np.nan]]
cols = pd.MultiIndex.from_tuples([('EUR, Oil (bbls)',c) for c in ['DE WITT','FAYETTE','GONZALES','LAVACA']], names=('','County'))
index = ['1776 ENERGY','ALMS','BURLINGTON','BXP','CHESAPEAKE']

recast=pd.DataFrame(data,index=index,columns=cols)
recast.index.name = 'ShortName'

print(recast)

            EUR, Oil (bbls)                         
 County             DE WITT FAYETTE  GONZALES LAVACA
ShortName                                           
1776 ENERGY   374000.000000     NaN  749500.0    NaN
ALMS                    NaN     NaN  298000.0    NaN
BURLINGTON    540065.326633     NaN  904750.0    NaN
BXP                     NaN     NaN  514000.0    NaN
CHESAPEAKE    411000.000000     NaN       NaN    NaN

我想做的是制作一个只有县名而不是它给我的列名的热图。当我在数据框上调用.info()时,我得到:

<class 'pandas.core.frame.DataFrame'>
Index: 29 entries, 1776 ENERGY to VERDUN OIL
Data columns (total 4 columns):
(EUR, Oil (bbls), DE WITT)     14 non-null float64
(EUR, Oil (bbls), FAYETTE)     3 non-null float64
(EUR, Oil (bbls), GONZALES)    23 non-null float64
(EUR, Oil (bbls), LAVACA)      5 non-null float64
dtypes: float64(4)

因此,默认情况下,热图上的我的x标签是:EUR, Oil (bbls), DE WITT,当我只想要:DE WITT

我尝试通过尝试重命名数据框中的列并尝试设置x_label来尝试消除不需要的部分,但无法让它工作。对于绘图我使用以下代码:

fig = plt.figure(71)
fig = sns.heatmap(recast, cmap='coolwarm', linewidths=0.25, linecolor='black')
fig = plt.xticks(rotation=0)
plt.tight_layout
plt.show()

另外,我想删除标签下方/左侧出现的X和Y标题(?)。

1 个答案:

答案 0 :(得分:0)

您可以使用.xs()来实现目标(see documentation)。

使用以下示例:

data = [[ 374000,np.nan,749500,np.nan],
        [np.nan,np.nan,298000,np.nan],
        [540065.326633,np.nan,904750,np.nan],
        [np.nan,np.nan,514000,np.nan],
        [411000,np.nan,np.nan,np.nan]]
cols = pd.MultiIndex.from_tuples([('EUR, Oil (bbls)',c) for c in ['DE WITT','FAYETTE','GONZALES','LAVACA']], names=('','County'))
index = ['1776 ENERGY','ALMS','BURLINGTON','BXP','CHESAPEAKE']

recast=pd.DataFrame(data,index=index,columns=cols)
recast.index.name = 'ShortName'

print(recast)
            EUR, Oil (bbls)                         
County              DE WITT FAYETTE  GONZALES LAVACA
ShortName                                           
1776 ENERGY   374000.000000     NaN  749500.0    NaN
ALMS                    NaN     NaN  298000.0    NaN
BURLINGTON    540065.326633     NaN  904750.0    NaN
BXP                     NaN     NaN  514000.0    NaN
CHESAPEAKE    411000.000000     NaN       NaN    NaN

print(recast.xs('EUR, Oil (bbls)',axis=1, drop_level=True))
#axis = 1 for column multi-index selection
#drop_level = True to drop 'EUR Oil (bbls)' column level

County             DE WITT  FAYETTE  GONZALES  LAVACA
ShortName                                            
1776 ENERGY  374000.000000      NaN  749500.0     NaN
ALMS                   NaN      NaN  298000.0     NaN
BURLINGTON   540065.326633      NaN  904750.0     NaN
BXP                    NaN      NaN  514000.0     NaN
CHESAPEAKE   411000.000000      NaN       NaN     NaN

要绘制你想要的内容,我用最小集合重写你的绘图部分:

_recast = recast.xs('EUR, Oil (bbls)',axis=1, drop_level=True)
fig = sns.heatmap(_recast, cmap='coolwarm', linewidths=0.25, linecolor='black')
fig.set_xlabel(' ')
fig.set_ylabel(' ')
plt.show()

recast heatmap

请注意,要删除X和Y标签,您只需使用fig.set_[x or y]label()函数重命名它们: