Overlapping y-axis labels in a Seaborn heatmap

Time: 2020-10-07 15:02:54

Tags: python pandas matplotlib seaborn heatmap

How can I stretch the heatmap out further without creating any white borders between the cells?

The dates on the y-axis overlap, and I want to spread them out.

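I tried increasing figsize in the subplots call, but the figure did not change when I changed the parameter. Is there any way to spread it out without borders between the cells?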

import seaborn as sns
import pandas as pd
import datetime as dt
import matplotlib.pyplot as plt
import yfinance as yf

#====================================================
prev=150
endDate=dt.datetime.today().date()
sDate=endDate-pd.to_timedelta(prev,unit='d')
#--------------------------------------------------------------
def get_price(tickers,roll_num=20): #input is a list or Series
    result=pd.DataFrame()
    pic=pd.DataFrame()
    for i in tickers:
        try:
            df=pd.DataFrame()                
            df['Adj Close']=yf.download(i,sDate,endDate)['Adj Close']
            df['MA']=df['Adj Close'].rolling(roll_num).mean()
            df.sort_index(ascending=False,inplace=True)  # newest dates first; does not depend on the index being named "Date"
            df['Higher?']=df['Adj Close']>df['MA']
            df['Higher?']=df['Higher?'].astype(int)
            result[str(i)]=df['Higher?']
            
        except Exception as ex:  # no date column
            print('Ticker', i, 'ERROR', ex)
            print(df)
    pic[tickers.name]=(result.sum(axis=1)/len(result.columns)*100).astype(int) 
    pic.name=tickers.name   
    pic.drop(pic.tail(roll_num-1).index,inplace=True)
    return pic
#--------------------------------------------------------------
test=pd.Series(['A','TSLA','KO','T','aapl','nke'])
test=test.str.replace('.','-',regex=False)  # literal '.', e.g. BRK.B -> BRK-B
test.name='I am test'
a=get_price(test)
print(a)
#=============================================================================

base_url = "http://www.sectorspdr.com/sectorspdr/IDCO.Client.Spdrs.Holdings/Export/ExportExcel?symbol="

data = {
    'Ticker': ['XLC', 'XLY', 'XLP', 'XLE', 'XLF', 'XLV', 'XLI', 'XLB', 'XLRE', 'XLK', 'XLU'],
    'Name': ['Communication Services', 'Consumer Discretionary', 'Consumer Staples', 'Energy', 'Financials',
             'Health Care', 'Industrials', 'Materials', 'Real Estate', 'Technology', 'Utilities'],
}

spdr_df = pd.DataFrame(data)     

print(spdr_df)

#-------------------------------------------------------------------
final_product=[]


for i, row in spdr_df.iterrows():
    url =  base_url + row['Ticker']
    df_url = pd.read_excel(url)
    header = df_url.iloc[0]
    holdings_df = df_url[1:]
    holdings_df = holdings_df.set_axis(header, axis='columns')  # set_axis(..., inplace=True) was removed in pandas 2.0
    holdings_df=holdings_df['Symbol'].str.replace('.','-',regex=False)
    holdings_df.name=row.Name
    final_product.append(get_price(holdings_df))
    
    

final_product=pd.concat(final_product,axis=1)
final_product['Sum']=final_product.sum(axis=1)
final_product.index=final_product.index.strftime('%Y-%m-%d')
print(final_product)
#------------------------------------------------

#----------------------------
plt.rcParams['ytick.labelsize']=12
fontsize_pt = plt.rcParams['ytick.labelsize']
dpi = 72.27  # points per inch (TeX points); converts point sizes to inches

column_labels = final_product.columns[:-1]

# compute the matrix height in points and inches
matrix_height_pt = fontsize_pt * final_product.shape[0]
matrix_height_in = matrix_height_pt / dpi

# compute the required figure height 
top_margin = 0.1  # as a fraction of the figure height
bottom_margin = 0.04 # as a fraction of the figure height
figure_height = matrix_height_in / (1 - top_margin - bottom_margin)


# build the figure instance with the desired height
fig, (ax1,ax2)= plt.subplots(ncols=2,figsize=(10,50), 
        gridspec_kw=dict(top=1-top_margin, bottom=bottom_margin,wspace=0.01))

# let seaborn do its thing
cmap = sns.diverging_palette(20, 145)
ax1 = sns.heatmap(final_product[final_product.columns[:-1]],cmap=cmap, vmin=0,vmax=100,annot=True,xticklabels=column_labels, cbar=False, ax=ax1, fmt='.0f')
                 
ax2 = sns.heatmap(final_product[final_product.columns[-1:]], cmap=cmap, vmin=0, vmax=1100, annot=True, fmt='.0f',yticklabels=[], cbar=False, ax=ax2)
ax2.set_ylabel('')
ax2.tick_params(axis='x', labelrotation=90)
ax1.xaxis.tick_top()
ax1.xaxis.set_label_position('top')
ax1.tick_params(axis='x', labelrotation=45)
plt.savefig('heatmap.png')

My output image is as follows: [image: heatmap output]
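The number in each cell comes from get_price: for every date it is the share of tickers whose adjusted close is above their own roll_num-day moving average, truncated to an integer percentage. A minimal offline sketch of that calculation follows, assuming a hypothetical `prices` DataFrame (one column per ticker, ascending DatetimeIndex) in place of the yfinance downloads; `percent_above_ma` is an illustrative name, not part of the question's code, and the prices are made up.

```python
import pandas as pd


def percent_above_ma(prices: pd.DataFrame, window: int = 20) -> pd.Series:
    """Integer percentage of tickers trading above their own rolling mean.

    `prices` is a hypothetical stand-in for the downloaded 'Adj Close'
    data: one column per ticker, indexed by date in ascending order.
    """
    ma = prices.rolling(window).mean()
    # NaN moving averages compare as False, so incomplete windows count as 0.
    above = (prices > ma).astype(int)
    pct = (above.sum(axis=1) / prices.shape[1] * 100).astype(int)
    # Drop rows without a full window, then list the newest dates first,
    # mirroring the drop/sort steps in get_price.
    return pct.iloc[window - 1:].sort_index(ascending=False)


dates = pd.date_range("2020-01-01", periods=4, freq="D")
prices = pd.DataFrame(
    {"A": [1, 2, 3, 4], "B": [4, 3, 2, 1], "C": [3, 1, 2, 5]},  # made-up prices
    index=dates,
)
print(percent_above_ma(prices, window=2))
```

For these made-up prices the result is 66, 66 and 33 for 2020-01-04, 2020-01-03 and 2020-01-02 (newest first); `astype(int)` truncates rather than rounds, just as in get_price.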

1 answer:

Answer 0 (score: 0)

I think you are overriding figsize with the gridspec_kw parameter in plt.subplots. Try changing the top and bottom parameters in gridspec_kw, or remove gridspec_kw.
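Following the suggestion to revisit the top and bottom values in gridspec_kw, one option is to let those same margins determine the figure height, instead of hard-coding figsize=(10,50) while the computed figure_height goes unused. A minimal sketch, assuming a hypothetical `table` shaped like final_product (date strings as the index, sector columns, Sum last) and the off-screen Agg backend. `plot_breadth` is an illustrative name, the data is random, and the sketch uses Matplotlib's 72 points per inch rather than the question's 72.27. It leaves seaborn's `linewidths` at its default of 0, so no lines are drawn between cells.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; no display needed
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

PT_PER_INCH = 72  # Matplotlib defines one point as 1/72 inch


def plot_breadth(table, fontsize_pt=12, top=0.9, bottom=0.04, width_in=10):
    """Two side-by-side heatmaps whose figure height grows with the row count.

    `table` is assumed to look like final_product: date strings as the
    index, one column per sector and a 'Sum' column last.
    """
    n_rows, n_cols = table.shape
    # Give every row one font size of height, then scale up so the
    # top/bottom margins passed to gridspec_kw do not eat into that space.
    fig_height = fontsize_pt * n_rows / PT_PER_INCH / (top - bottom)
    fig, (ax1, ax2) = plt.subplots(
        ncols=2,
        figsize=(width_in, fig_height),
        gridspec_kw=dict(top=top, bottom=bottom, wspace=0.01),
    )
    cmap = sns.diverging_palette(20, 145, as_cmap=True)
    # yticklabels=True labels every date instead of letting seaborn thin them.
    sns.heatmap(table.iloc[:, :-1], cmap=cmap, vmin=0, vmax=100, annot=True,
                fmt=".0f", yticklabels=True, cbar=False, ax=ax1)
    # The Sum column can reach 100 per sector, hence vmax = 100 * sectors.
    sns.heatmap(table.iloc[:, -1:], cmap=cmap, vmin=0, vmax=100 * (n_cols - 1),
                annot=True, fmt=".0f", yticklabels=False, cbar=False, ax=ax2)
    ax1.tick_params(axis="y", labelsize=fontsize_pt, labelrotation=0)
    ax1.xaxis.tick_top()
    ax1.tick_params(axis="x", labelrotation=45)
    ax2.set_ylabel("")
    return fig, ax1, ax2


# Hypothetical data shaped like final_product (random values, not real results).
rng = np.random.default_rng(0)
dates = pd.date_range("2020-05-01", periods=60, freq="D").strftime("%Y-%m-%d")
table = pd.DataFrame(rng.integers(0, 101, size=(60, 2)),
                     index=dates, columns=["Energy", "Utilities"])
table["Sum"] = table.sum(axis=1)
fig, ax1, ax2 = plot_breadth(table)
# fig.savefig("heatmap.png")  # write the result to disk if needed
```

Each row ends up exactly one font size tall, so neighboring labels just touch; multiplying the height by a small factor such as 1.2 would leave a little space between them.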