创建3D散点图python

时间:2019-09-22 17:51:53

标签: python numpy matplotlib scatter

我知道这个问题已经以不同的方式提出来了,但是我不知道为什么我的代码没有给我输出。

我正在使用虹膜训练数据集,如下所示:

import pandas as pd
df = pd.read_csv('https://archive.ics.uci.edu/ml/'
                  'machine-learning-databases/iris/iris.data',
                  header=None)

我使用以下数据:

# selecting setosa and versicolor
y = df.iloc[0:100, 4].values

y = np.where(y == 'Iris-setosa', -1, 1)

# extract sepal length, sepal width and petal length
X = df.iloc[0:100, [0, 1, 2]].values

ppn = Perceptron(eta0=0.1, max_iter=10)

然后,我使用以下辅助函数:

def plot_decision_regions_3(X, y, classifier, resolution=0.02):


    # plot the decision surface
    x1_min, x1_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    x2_min, x2_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    x3_min, x3_max = X[:, 2].min() - 1, X[:, 2].max() + 1

    xx1, xx2, xx3 = np.meshgrid(np.arange(x1_min, x1_max, resolution),
                           np.arange(x2_min, x2_max, resolution),
                           np.arange(x3_min, x3_max, resolution))

    Z = classifier.predict(np.array([xx1.ravel(), xx2.ravel(), xx3.ravel()]).T)

    Z = Z.reshape(xx1.shape)
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    ax.scatter(xx1, xx2, xx3)

    plt.xlim(xx1.min(), xx1.max())
    plt.ylim(xx2.min(), xx2.max())
    plt.set_zlim(xx3.min(), xx3.max())

最后,调用相同的名称:

plot_decision_regions_3(X, y, classifier=ppn)
plt.xlabel('sepal length [cm]')
plt.ylabel('petal length [cm]')
plt.legend(loc='upper left')
plt.show()

我知道这与我使用散点函数的方式有关,但是我无法弄清楚我的错误是什么。感谢您的帮助

谢谢!

0 个答案:

没有答案