我要可视化的数据集中有9个输入和2个输出。我正在使用GridSpec
绘制图表。我为每个输出分别绘制了它相对于每个输入的散点图,并绘制了一张涵盖所有输入与所有输出的龙卷风图(见下图)。
从上图中可以看到2行散点图(x0和x00),然后是龙卷风图。
问题:是否可以在每行散点图的末尾添加龙卷风图?
这是我的代码:
import numpy as np
from matplotlib import pyplot as plt
from matplotlib import gridspec
# Load the CSV as a structured array so each column can be addressed by its
# header name.
dataset1 = np.genfromtxt('dataSet1.csv', dtype=float, delimiter=',', names=True)

# Columns whose header starts with 'x' are treated as outputs; every other
# column is an input.
column_names = dataset1.dtype.names
li_output = [name for name in column_names if name.startswith('x')]
li_input = [name for name in column_names if not name.startswith('x')]

print('Input => {}\n'.format(li_input))
print('Output => {}\n'.format(li_output))

# Correlation coefficient of every (input, output) pair, stored both as a
# flat list and as a dict keyed "<input>_<output>".
corr_list = []
corr_dict = {}
for out_name in li_output:
    for in_name in li_input:
        coefficient = np.corrcoef(dataset1[in_name], dataset1[out_name])[0, 1]
        corr_list.append(coefficient)
        corr_dict[in_name + '_' + str(out_name)] = coefficient
# --- visualisation setup ---
# One 8x8 inch figure and two 3-row grids with equal row heights: `gs` has one
# column per input, `gs1` has a single full-width column.
fig = plt.figure(figsize=(8, 8))
gs = gridspec.GridSpec(nrows=3, ncols=len(li_input), height_ratios=[1, 1, 1])
gs1 = gridspec.GridSpec(nrows=3, ncols=1, height_ratios=[1, 1, 1])
def tornado(variables, values, ax):
    """Draw a tornado chart (horizontal bars sorted by magnitude) on ``ax``.

    Bars start at 0 and extend by their signed value; the bar with the
    largest absolute value is placed at the top. Labels are shown on the
    right-hand side of the axes and the x-range is fixed to [-1, 1].

    Args:
        variables: Bar labels, one per entry of ``values``.
        values: Signed bar lengths, parallel to ``variables``.
        ax: The matplotlib axes to draw on. All drawing and styling is
            applied to this axes rather than to pyplot's "current" axes.

    Returns:
        None.
    """
    # Kept from the original implementation: this changes NumPy's global
    # print precision as a side effect.
    np.set_printoptions(precision=4)
    base = 0

    # Sort labels and values together by descending magnitude. Sorting the
    # pairs once keeps both sequences aligned, and avoids subscripting a
    # ``zip`` object (a TypeError on Python 3) or indexing an empty result.
    ranked = sorted(
        zip(variables, values),
        key=lambda pair: abs(pair[1]),
        reverse=True,
    )
    variables = tuple(label for label, _ in ranked)
    values = [value for _, value in ranked]

    # Y position for each variable, top to bottom.
    ys = list(range(len(values)))[::-1]

    # Plot the bars one by one. broken_barh expects (xmin, xwidth) pairs, so
    # the width is the value itself, not base + value.
    for y, value in zip(ys, values):
        ax.broken_barh(
            [(base, value)],
            (y - 0.4, 0.8),
            facecolors='red',
            edgecolors='black',
            linewidth=1,
        )

    # Vertical reference line at the base value.
    ax.axvline(base, color='black')

    # Keep only the bottom spine and put the x ticks there.
    ax.spines['left'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.xaxis.set_ticks_position('bottom')

    # Label each bar and show the labels on the right-hand side. Booleans are
    # required here: the strings 'off'/'on' are both truthy.
    ax.set_yticks(ys)
    ax.set_yticklabels(variables)
    ax.tick_params(axis='y', which='both', labelleft=False, labelright=True)

    # Fixed x-range; leave some empty space below the lowest bar.
    ax.set_xlim(-1, 1)
    ax.set_ylim(-2, len(variables))

    # Request a redraw of the figure that owns this axes.
    ax.figure.canvas.draw_idle()
def plot_correlation():
    """Scatter every input column against every output column.

    Subplots are taken from ``gs`` by flat index: the block of cells used for
    one output starts at ``row * len(li_input)``, and each input takes the
    next cell in that block. The x axis is labelled with the input name and
    the y axis with the output name.
    """
    n_inputs = len(li_input)
    for row, out_name in enumerate(li_output):
        row_offset = row * n_inputs
        for col, in_name in enumerate(li_input):
            panel = fig.add_subplot(gs[row_offset + col])
            panel.scatter(dataset1[in_name], dataset1[out_name], marker='.')
            panel.set_xlabel(in_name)
            panel.set_ylabel(out_name)
def plot_op():
    """Draw the tornado chart of all stored correlations in cell 2 of ``gs1``.

    The labels are the keys of ``corr_dict`` and the bar values are its
    values, converted to a NumPy array.
    """
    cell_index = 2
    tornado_ax = plt.subplot(gs1[cell_index])
    labels = list(corr_dict.keys())
    coefficients = np.array(list(corr_dict.values()))
    tornado(labels, coefficients, tornado_ax)
# Build the scatter plots first, then the tornado chart, and finally display
# the figure.
plot_correlation()
plot_op()
plt.show()
非常感谢任何帮助。