请看下面这个数据文件 data-file.txt:
75,15,1,57.5,9.9,5
75,15,1,58.1,10.0,5
75,15,2,37.9,8.3,5
75,15,2,18.2,7.3,5
150,15,1,26.4,8.3,10
150,15,1,31.6,7.9,10
150,15,2,30.6,7.5,10
150,15,2,25.1,7.1,10
注意,第3列的取值只有1和2。
我想生成一个 3x2 的直方图网格。下面的子图看起来是对的,但每一行都应包含来自不同数据子集的2个直方图;也就是说,我会根据最后一列对数据进行过滤。
关键代码是 `ax.hist(X[(y==grp) & (X[:,2]==1), cols], ...)`,过滤条件就在其中。
我希望每行有2个直方图:第一行使用 `(X[:,2]==*)`,其中 `*` 表示第3列中的任意值(1或2);第二行使用 `(X[:,2]==1)`;第三行使用 `(X[:,2]==2)`。总之,我希望第2行和第3行的直方图分别基于以下过滤后的数据:
第3列的值为1:
75,15,1,57.5,9.9,5
75,15,1,58.1,10.0,5
150,15,1,26.4,8.3,10
150,15,1,31.6,7.9,10
第3列的值为2:
75,15,2,37.9,8.3,5
75,15,2,18.2,7.3,5
150,15,2,30.6,7.5,10
150,15,2,25.1,7.1,10
代码:
#!/usr/bin/python
# -*- coding: utf-8 -*-
# Question's script: for every pair of group labels taken from the last
# column ("PP"), overlay histograms of features "P" (index 3) and "T"
# (index 4), keeping only rows whose column "G" (index 2) equals 1.
# Written to run under both Python 2 and Python 3.
#
# NOTE: zip(axes.ravel(), cols_to_hist) pairs the 6 axes of the 3x2 grid
# with only the 2 entries of cols_to_hist, so only the first row of the
# grid is drawn and every panel uses the same G == 1 filter -- this is the
# behaviour the question is asking about.
import pandas as pd
import numpy as np
import math
from matplotlib import pyplot as plt
from itertools import combinations

data_file = 'data-file.txt'

# The data file has no header line; header=None keeps the first data row
# instead of silently turning it into column names.
df = pd.read_csv(data_file, header=None)

N = df.shape[1]
# 1-based feature index -> column name (assumes the file has 6 columns).
feature_dict = {i + 1: label for i, label in zip(
    range(N),
    ('L', 'A', 'G', 'P', 'T', 'PP'),
)}
df.columns = [label for _, label in sorted(feature_dict.items())]

# Features are every column except the last, selected by position: the
# columns are now string-labelled, so df[range(N - 1)] would raise KeyError.
X = df.iloc[:, :N - 1].values
# Group labels come from the last column.
y = df['PP'].values

# 1-based group index -> group label. tolist() yields plain Python scalars,
# so the progress message below prints e.g. "[5, 10]".
label_dict = {i + 1: label for i, label in enumerate(sorted(set(y.tolist())))}
grps_to_hist_list = [list(pair) for pair in combinations(sorted(label_dict), 2)]

cols_to_hist = [3, 4]

for grps_to_hist in grps_to_hist_list:
    grps_str = [label_dict[grps_to_hist[0]], label_dict[grps_to_hist[1]]]
    print('creating histogram for groups %s from data file %s' % (grps_str, data_file))
    fig, axes = plt.subplots(nrows=3, ncols=2, figsize=(18, 8))
    for ax, cols in zip(axes.ravel(), cols_to_hist):
        # bin edges span the whole (unfiltered) feature column
        min_b = math.floor(np.min(X[:, cols]))
        max_b = math.ceil(np.max(X[:, cols]))
        bins = np.linspace(min_b, max_b, 40)

        # one semi-transparent histogram per group, restricted to G == 1
        for grp, color in zip(grps_str, ('blue', 'red')):
            ax.hist(X[(y == grp) & (X[:, 2] == 1), cols],
                    color=color,
                    label='%s' % grp,
                    bins=bins,
                    alpha=0.3)
        ylims = ax.get_ylim()

        # plot annotation
        leg = ax.legend(loc='upper right', fancybox=True, fontsize=8)
        leg.get_frame().set_alpha(0.5)
        ax.set_ylim([0, max(ylims) + 2])
        ax.set_xlabel(feature_dict[cols + 1])
        ax.set_title('%s' % str(data_file))

        # hide tick marks but keep tick labels; matplotlib expects booleans
        # here (the string "off" is truthy and would not hide anything)
        ax.tick_params(axis="both", which="both", bottom=False, top=False,
                       labelbottom=True, left=False, right=False,
                       labelleft=True)

        # remove axis spines
        for side in ("top", "right", "bottom", "left"):
            ax.spines[side].set_visible(False)

    fig.tight_layout()
    plt.show()
下面是上述代码使用过滤条件 `(y==grp) & (X[:,2]==1)` 时的运行截图(这些直方图本应位于第2行)。
答案 0(得分:1)
我的思路是逐行迭代,并为每一行使用你所选的对应掩码 `[(X[:,2]==1) | (X[:,2]==2), X[:,2]==1, X[:,2]==2]`。希望这就是你想要的效果:
#!/usr/bin/python
# -*- coding: utf-8 -*-
# Answer's script: for every pair of group labels taken from the last
# column ("PP"), draw a 3x2 grid of overlaid histograms.
#   Grid rows:    column "G" (index 2) in {1, 2}, then G == 1, then G == 2.
#   Grid columns: features "P" (index 3) and "T" (index 4).
# Written to run under both Python 2 and Python 3.
import pandas as pd
import numpy as np
import math
from matplotlib import pyplot as plt
from itertools import combinations

data_file = 'data-file.txt'

# The data file has no header line; header=None keeps the first data row
# instead of silently turning it into column names.
df = pd.read_csv(data_file, header=None)

N = df.shape[1]
# 1-based feature index -> column name (assumes the file has 6 columns).
feature_dict = {i + 1: label for i, label in zip(
    range(N),
    ('L', 'A', 'G', 'P', 'T', 'PP'),
)}
df.columns = [label for _, label in sorted(feature_dict.items())]

# Features are every column except the last, selected by position: the
# columns are now string-labelled, so df[range(N - 1)] would raise KeyError.
X = df.iloc[:, :N - 1].values
# Group labels come from the last column.
y = df['PP'].values

# 1-based group index -> group label. tolist() yields plain Python scalars,
# so the progress message below prints e.g. "[5, 10]".
label_dict = {i + 1: label for i, label in enumerate(sorted(set(y.tolist())))}
grps_to_hist_list = [list(pair) for pair in combinations(sorted(label_dict), 2)]

cols_to_hist = [3, 4]

# One boolean row filter per grid row, all based on column "G".
g_col = X[:, 2]
row_masks = [(g_col == 1) | (g_col == 2), g_col == 1, g_col == 2]

# Bin edges per feature, computed once over the whole (unfiltered) column
# so that all three rows of a grid column share the same bins.
bins_by_col = {}
for col in cols_to_hist:
    min_b = math.floor(np.min(X[:, col]))
    max_b = math.ceil(np.max(X[:, col]))
    bins_by_col[col] = np.linspace(min_b, max_b, 40)

for grps_to_hist in grps_to_hist_list:
    grps_str = [label_dict[grps_to_hist[0]], label_dict[grps_to_hist[1]]]
    print('creating histogram for groups %s from data file %s' % (grps_str, data_file))
    fig, axes = plt.subplots(nrows=3, ncols=2, figsize=(18, 8))
    for row_ax, row_mask in zip(axes, row_masks):
        for ax, cols in zip(row_ax, cols_to_hist):
            # one semi-transparent histogram per group, restricted to the
            # rows selected by this grid row's mask
            for grp, color in zip(grps_str, ('blue', 'red')):
                ax.hist(X[(y == grp) & row_mask, cols],
                        color=color,
                        label='%s' % grp,
                        bins=bins_by_col[cols],
                        alpha=0.3)
            ylims = ax.get_ylim()

            # plot annotation
            leg = ax.legend(loc='upper right', fancybox=True, fontsize=8)
            leg.get_frame().set_alpha(0.5)
            ax.set_ylim([0, max(ylims) + 2])
            ax.set_xlabel(feature_dict[cols + 1])
            ax.set_title('%s' % str(data_file))

            # hide tick marks but keep tick labels; matplotlib expects
            # booleans here (the string "off" is truthy and would not hide
            # anything)
            ax.tick_params(axis="both", which="both", bottom=False, top=False,
                           labelbottom=True, left=False, right=False,
                           labelleft=True)

            # remove axis spines
            for side in ("top", "right", "bottom", "left"):
                ax.spines[side].set_visible(False)

    fig.tight_layout()
    plt.show()