如何使用python绘制多个子图

时间:2019-09-27 11:27:54

标签: python numpy matplotlib scikit-learn

我遇到一个问题,我被赋予了图像,必须使用python和matplotlib,sklearn,numpy重新创建该图像。以下是图片:

Result Wanted Picture

这是我到目前为止在python中编写的代码:

import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
import numpy as np

iris = load_iris()
print(type(iris))
print(iris['target_names'])
print(iris['data'])
print(iris['target'])
print(iris['DESCR'])
print(iris['feature_names'])

fig = plt.figure()
ax1 = plt.subplot(2, 1, 1)
ax2 = plt.subplot(2, 1, 2)


iris = load_iris()
data = np.array(iris['data'])
targets = np.array(iris['target'])

cd = {0: 'r', 1: 'b', 2: 'g'}

cols = np.array([cd[target] for target in targets])

ax1.scatter(data[:, 0], data[:, 1], c=cols)
ax2.scatter(data[:, 0], data[:, 2], c=cols)
plt.show()

我完全迷路了,真的需要帮助才能克服这个困难,我只正确地正确完成了前2个子图。任何建议都将非常有帮助,因为我几天前一直在努力找出这一点。

3 个答案:

答案 0 :(得分:0)

获得带有相应子批次的图形的一种方法是

fig, subs = plt.subplots(4,3)
然后

subs是ares的二维数组,因此您可以执行以下操作:

subs[0][0].scatter(x,y)

答案 1 :(得分:0)

这是一个例子

from matplotlib import pyplot as plt 
import numpy as np

x = np.linspace(-5, 5, 10)
y = np.random.rand(10)

fig, ax = plt.subplots(nrows=4, ncols=3, figsize=(8, 6))

# ax is a 2d array with shape (4, 3), it can be sliced just like a numpy array 

for row in range(4):
    for col in range(3):
        ax[row][col].scatter(x, y, c='color you want')

plt.show()

答案 2 :(得分:0)

这有望确切解释如何创建所需的图像:

import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
import numpy as np

fig, subs = plt.subplots(4,3) #setting the shape of the figure in one line as opposed to creating 12 variables 

iris = load_iris() ##code as per the example 
data = np.array(iris['data'])
targets = np.array(iris['target'])

cd = {0:'r',1:'b',2:"g"}
cols = np.array([cd[target] for target in targets])


# Row 1 

subs[0][0].scatter(data[:,0], data[:,1], c=cols)
subs[0][1].scatter(data[:,0], data[:,2], c=cols)
subs[0][2].scatter(data[:,0], data[:,3], c=cols)

# Row 2 

subs[1][0].scatter(data[:,1], data[:,0], c=cols)
subs[1][1].scatter(data[:,1], data[:,2], c=cols)
subs[1][2].scatter(data[:,1], data[:,3], c=cols)

# Row 3 

subs[2][0].scatter(data[:,2], data[:,0], c=cols)
subs[2][1].scatter(data[:,2], data[:,1], c=cols)
subs[2][2].scatter(data[:,2], data[:,3], c=cols)

#Row 4 

subs[3][0].scatter(data[:,3], data[:,0], c=cols)
subs[3][1].scatter(data[:,3], data[:,1], c=cols)
subs[3][2].scatter(data[:,3], data[:,2], c=cols)

plt.show()