用seaborn stripplot中的色调绘制宽幅矩阵

时间:2016-12-02 15:41:12

标签: python seaborn

我正在尝试使用stripplot绘制数据集。这是头部(有25列):

    Labels  Acidobacteria  Actinobacteria  Armatimonadetes  Bacteroidetes  
0       0              0             495              NaN          27859   
1       1              0            1256              NaN          46582
2       0              0            1081              NaN          23798   
3       1              0            2523              NaN          35088   
4       0              0            1383              NaN          19338  

我将此数据集存储在pandas DataFrame中,并可以使用以下方式绘制它:

   def plot():
    ax = sns.stripplot(data = df)
    ax.set(xlabel='Bacteria',ylabel='Abundance')
    plt.setp(ax.get_xticklabels(),rotation=45)
    plt.show()

制作this plot

我想设置色调以反映'Labels'列。当我尝试:

sns.stripplot(x=df.columns.values.tolist(),y=df,data=df,hue='Labels') 

我明白了:

ValueError: cannot copy sequence with size 26 to array axis with dimension 830

2 个答案:

答案 0 :(得分:3)

所以我明白了。我不得不通过堆叠和重新索引来重新排列数据:

cols = df.columns.values.tolist()[3:]
stacked = df[cols].stack().reset_index()
stacked.rename(columns={'level_0':'index','level_1':'Bacteria',0:'Abundance'},inplace=True)

哪个输出:

           index          Bacteria  Abundance
0          0     Acidobacteria   0.000000
1          0    Actinobacteria   0.005003
2          0   Armatimonadetes   0.000000
3          0     Bacteroidetes   0.281586

接下来,我必须创建一个新列,为每个数据点分配标签:

label_col = np.array([[label for _ in range(len(cols))] for label in df['Labels']])
label_col = label_col.flatten()

stacked['Labels'] = label_col

现在:

   index         Bacteria  Abundance  Labels
0      0    Acidobacteria   0.000000       0
1      0   Actinobacteria   0.005003       0
2      0  Armatimonadetes   0.000000       0
3      0    Bacteroidetes   0.281586       0
4      0       Chlamydiae   0.000000       0

然后绘制:

def plot():
    ax = sns.stripplot(x='Bacteria',y='Abundance',data=stacked,hue='Labels',jitter=True)
    ax.set(xlabel='Bacteria',ylabel='Abundance')
    plt.setp(ax.get_xticklabels(),rotation=45)
    plt.show()
plot()

制作this graph

感谢您的帮助!

答案 1 :(得分:0)

我想扩展您的答案(实际上,我会对其进行压缩),因为这可以用“单线”完成:

# To select specific columns:
cols = ["Acidobacteria", "Actinobacteria", "Armatimonadetes", "Bacteroidetes"]
df.set_index("Labels")[cols]\
    .stack()\
    .reset_index()\
    .rename(columns={'level_1':'Bacteria', 0:'Abundance'})

# If you want to stack all columns but "Labels", this is enough:
df.set_index("Labels")\
    .stack()\
    .reset_index()\
    .rename(columns={'level_1':'Bacteria', 0:'Abundance'})

避免重新创建"Labels"列的诀窍是在堆叠之前将其设置为索引

输出:

    Labels        Bacteria  Abundance
0        0   Acidobacteria        0.0
1        0  Actinobacteria      495.0
2        0   Bacteroidetes    27859.0
3        1   Acidobacteria        0.0
4        1  Actinobacteria     1256.0
5        1   Bacteroidetes    46582.0
6        0   Acidobacteria        0.0
7        0  Actinobacteria     1081.0
8        0   Bacteroidetes    23798.0
9        1   Acidobacteria        0.0
10       1  Actinobacteria     2523.0
11       1   Bacteroidetes    35088.0
12       0   Acidobacteria        0.0
13       0  Actinobacteria     1383.0
14       0   Bacteroidetes    19338.0