在python中使用pyplot绘制多个图

时间:2018-09-22 14:39:08

标签: python matplotlib scikit-learn

我想在数据中所有特征组合之间绘制散点图。为此,我使用下面的代码,但是我得到了重叠的图。

#importing the important libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn import svm
from sklearn.cross_validation import train_test_split
from sklearn import metrics
from sklearn import datasets

wine_data = datasets.load_wine()

#exploring the ralationship between the data by visualizing it.
i = 1
plt.figure(figsize=(15,15))
for feature_x_pos,feature_x in enumerate(wine_data.feature_names):
  for feature_y_pos,feature_y in enumerate(wine_data.feature_names):
    if feature_x_pos != feature_y_pos:
      plt.subplot(60,3,i)
      plt.scatter(wine_data.data[:,feature_x_pos],wine_data.data[:,feature_y_pos],c = wine_data.target, cmap = 'jet')
      plt.xlabel(feature_x)
      plt.ylabel(feature_y)
      i=i+1

葡萄酒数据包含13个功能。我想在所有特征对之间绘制散点图。 上面代码的输出如下:

enter image description here

我正在使用Google colab编写代码。

请帮助避免图形重叠。

2 个答案:

答案 0 :(得分:0)

两种解决方案:

1。。请尝试在代码末尾添加plt.tight_layout(),以消除重叠。

    i = 1
    plt.figure(figsize=(15,15))
    for feature_x_pos,feature_x in enumerate(wine_data.feature_names):
      for feature_y_pos,feature_y in enumerate(wine_data.feature_names):
        if feature_x_pos != feature_y_pos:
          plt.subplot(60,3,i)
          plt.scatter(wine_data.data[:,feature_x_pos],wine_data.data[:,feature_y_pos],c = wine_data.target, cmap = 'jet')
          plt.xlabel(feature_x)
          plt.ylabel(feature_y)
          i=i+1;

plt.tight_layout()

2。。创建180个图形,而不是包含180个图形。

    i = 1

    for feature_x_pos,feature_x in enumerate(wine_data.feature_names):
      for feature_y_pos,feature_y in enumerate(wine_data.feature_names):
        if feature_x_pos != feature_y_pos:
          fig, ax = plt.subplots(1,1)
          ax.scatter(wine_data.data[:,feature_x_pos],wine_data.data[:,feature_y_pos],c = wine_data.target, cmap = 'jet')
          ax.set_xlabel(feature_x)
          ax.set_ylabel(feature_y)
          fig.show()
          i=i+1;

答案 1 :(得分:0)

我得到了解决方案,只是增加了图形的长度。

#exploring the ralationship between the data by visualizing it.
i = 1
plt.figure(figsize=(15,200)) #changed the length from 15 to 200
for feature_x_pos,feature_x in enumerate(wine_data.feature_names):
  for feature_y_pos,feature_y in enumerate(wine_data.feature_names):
    if feature_x_pos != feature_y_pos:
      plt.subplot(60,3,i)
      plt.scatter(wine_data.data[:,feature_x_pos],wine_data.data[:,feature_y_pos],c = wine_data.target, cmap = 'jet')
      plt.xlabel(feature_x)
      plt.ylabel(feature_y)
      i=i+1

谢谢大家的评论和指导:)