Python:在每个子图行的末尾添加一个子图

时间:2016-10-26 16:40:02

标签: python python-2.7 matplotlib graph

我要可视化的数据集有 9 个输入和 2 个输出,图表用 GridSpec 排版。对每个输出,我分别绘制了它与每个输入之间的散点图;另外还绘制了一张龙卷风图,展示所有输入与所有输出之间的相关系数。见下图:

[图片:原帖中的示意图,此处未能保留]

从上图可以看到两行散点图(分别对应输出 x0 和 x00),其下方是龙卷风图。

问题:是否可以在每行散点图的末尾添加龙卷风图?

这是我的代码:

import numpy as np
from matplotlib import pyplot as plt
from matplotlib import gridspec

dataset1 = np.genfromtxt('dataSet1.csv', dtype=float, delimiter=',',
                         names=True)

# Split the columns by name: anything starting with 'x' is an output,
# every other column is an input. Column order is preserved in both lists.
field_names = dataset1.dtype.names
li_output = [name for name in field_names if name.startswith('x')]
li_input = [name for name in field_names if not name.startswith('x')]
print('Input => {}\n'.format(li_input))
print('Output => {}\n'.format(li_output))

# Correlation coefficient (np.corrcoef, off-diagonal entry) for every
# (output, input) pair: kept in order in corr_list, and by the key
# '<input>_<output>' in corr_dict.
corr_list = []
corr_dict = {}
for i in li_output:
    for j in li_input:
        coefficient = np.corrcoef(dataset1[j], dataset1[i])[0, 1]
        corr_list.append(coefficient)
        corr_dict[j + '_' + str(i)] = coefficient

# --- visualisation ---
fig = plt.figure(figsize=(8, 8))
# Scatter grid: three equal-height rows, one column per input.
gs = gridspec.GridSpec(nrows=3, ncols=len(li_input), height_ratios=[1, 1, 1])
# Single-column grid with the same three rows; one full-width cell per row.
gs1 = gridspec.GridSpec(nrows=3, ncols=1, height_ratios=[1, 1, 1])


def tornado(variables, values, ax, base=0):
    """Draw a tornado chart (horizontal bars sorted by magnitude) on *ax*.

    Bars are ordered by ``abs(value)``, largest at the top. Each bar starts
    at ``base`` and extends by ``value``: to the left for negative values,
    to the right for positive ones. Bar labels are drawn on the right side.

    All drawing and styling is applied to *ax* itself, not to pyplot's
    "current" axes. The chart can therefore be placed in any subplot, for
    example the last cell of each scatter row.

    Parameters
    ----------
    variables : sequence of str
        Bar labels, parallel to *values*.
    values : sequence of float
        Bar lengths, e.g. correlation coefficients. The x-range is fixed
        to [-1, 1] below.
    ax : matplotlib.axes.Axes
        Axes to draw on.
    base : float, optional
        Where bars start and where the vertical reference line is drawn.
        Defaults to 0.
    """
    # Kept from the original: this changes NumPy's *global* print precision.
    # It has no effect on the plot.
    np.set_printoptions(precision=4)

    # Sort (label, value) pairs together so labels cannot drift out of step
    # with their values. This also avoids subscripting zip(), which fails on
    # Python 3 where zip() returns an iterator.
    ranked = sorted(zip(variables, values), key=lambda pair: abs(pair[1]),
                    reverse=True)
    labels = [name for name, _ in ranked]
    ordered_values = [value for _, value in ranked]

    # Y position of each bar: the largest magnitude gets the top slot.
    ys = list(range(len(ordered_values)))[::-1]

    for y, value in zip(ys, ordered_values):
        # broken_barh takes (xmin, xwidth) pairs. The width is the value
        # itself, so negative values extend to the left of `base`.
        ax.broken_barh(
            [(base, value)],
            (y - 0.4, 0.8),
            facecolors=['red', 'red'],
            edgecolors=['black', 'black'],
            linewidth=1)

    # Vertical reference line through the base value.
    ax.axvline(base, color='black')

    # Keep only the bottom spine, with the x ticks drawn along it.
    for side in ('left', 'right', 'top'):
        ax.spines[side].set_visible(False)
    ax.xaxis.set_ticks_position('bottom')

    # Label the bars on the right-hand side. Booleans are used because the
    # 'on'/'off' string forms are deprecated and rejected by newer matplotlib.
    ax.set_yticks(ys)
    ax.set_yticklabels(labels)
    ax.tick_params(axis='y', which='both', labelleft=False, labelright=True)

    ax.set_xlim(-1, 1)
    ax.set_ylim(-2, len(labels))

    # Redraw the figure that owns *ax*. This is equivalent to plt.draw()
    # when *ax* belongs to the current figure.
    ax.figure.canvas.draw_idle()


def plot_correlation():
    """Scatter every input against every output on the ``gs`` grid.

    Grid row ``r`` holds output ``li_output[r]`` and grid column ``c``
    holds input ``li_input[c]``. Each panel puts the input on the x axis
    and the output on the y axis, labelled accordingly.
    """
    for row, out_name in enumerate(li_output):
        for col, in_name in enumerate(li_input):
            # gs has len(li_input) columns, so gs[row, col] is the same cell
            # as the flat index row * len(li_input) + col.
            panel = fig.add_subplot(gs[row, col])
            panel.scatter(dataset1[in_name], dataset1[out_name], marker='.')
            panel.set_xlabel(in_name)
            panel.set_ylabel(out_name)

def plot_op(num=2):
    """Draw the tornado chart of every input/output correlation.

    The chart is placed in cell ``num`` of the single-column grid ``gs1``,
    spanning the full figure width.

    Parameters
    ----------
    num : int, optional
        Index of the ``gs1`` cell to draw in. Defaults to 2, the cell the
        original layout reserved for the tornado chart.
    """
    # Use fig.add_subplot, as plot_correlation does. pyplot.subplot acts on
    # whatever figure is current, and in older matplotlib versions it
    # silently deletes any existing axes that overlap the new one.
    ax1 = fig.add_subplot(gs1[num])
    # keys() and values() of an unmodified dict iterate in matching order,
    # so each label stays paired with its correlation value.
    variables_op1 = list(corr_dict.keys())
    values_op1 = np.array(list(corr_dict.values()))
    tornado(variables_op1, values_op1, ax1)

# Build the scatter grid, then the tornado chart, then display the figure.
for build_step in (plot_correlation, plot_op):
    build_step()
plt.show()

非常感谢任何帮助。

0 个答案:

没有答案