Scikit decision tree error

Date: 2015-11-04 02:58:45

Tags: python numpy machine-learning scikit-learn decision-tree

I get an error when I try to visualize the predicted labels of the test data.

import scipy.io
from sklearn import tree
import numpy as np
import pydot
from sklearn.externals.six import StringIO


def main():
    # Training matrix: columns 0-2 are features, column 3 is the label
    mat = scipy.io.loadmat('m.mat')
    x = np.array(mat['m'][:, :3])
    y = np.array(mat['m'][:, 3])
    print("training data: ", x)
    print("training label: ", y)
    testdata = np.loadtxt('testing.txt')
    print("testing data: ", testdata)
    clf = tree.DecisionTreeClassifier(min_samples_split=20, random_state=99, criterion="entropy")
    clf = clf.fit(x, y)          # fit() returns the fitted classifier
    clf = clf.predict(testdata)  # rebinds clf to the ndarray of predicted labels
    print("NewLabel:", clf)
    out = StringIO()
    # Write the tree in Graphviz DOT format, then render it to PDF with pydot
    tree.export_graphviz(clf, out_file=out)
    pydot.graph_from_dot_data(out.getvalue()).write_pdf("output.pdf")


if __name__ == "__main__":
    main()

The labels are printed in the correct format:

('training data: ', array([[  1.00000000e+00,   6.00000000e+00,   1.00000000e+00],
       [  1.00000000e+00,   1.60000000e+02,   3.10000000e+01],
       [  1.00000000e+00,   4.38000000e+02,   1.00000000e+00],
       ..., 
       [  1.84200000e+03,   2.60830000e+04,   7.80000000e-01],
       [  1.84200000e+03,   2.61810000e+04,   5.20000000e-01],
       [  1.84200000e+03,   2.61840000e+04,   1.00000000e+00]]))
('training label: ', array([ 1.,  1.,  1., ...,  1.,  1.,  1.]))
('testing data: ', array([[  1.00000000e+00,   1.51000000e+02,   3.00000000e+00],
       [  1.00000000e+00,   1.60000000e+02,   3.30000000e+01],
       [  1.00000000e+00,   2.65000000e+02,   3.00000000e+00],
       ..., 
       [  9.52000000e+02,   2.59300000e+04,   2.60000000e+01],
       [  9.52000000e+02,   2.60830000e+04,   9.60000000e-01],
       [  9.52000000e+02,   2.61810000e+04,   1.00000000e+00]]))
('NewLabel:', array([ 1.,  1.,  1., ...,  1.,  1.,  1.]))
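The printed NewLabel array is truncated with "...", so it is hard to tell from it alone whether the tree ever predicts anything other than 1. A minimal sketch for summarizing the predictions is to count each label with numpy.unique; the variable name new_label and its values below are hypothetical stand-ins for whatever clf.predict(testdata) actually returns.

```python
import numpy as np

# Hypothetical stand-in for the array returned by clf.predict(testdata);
# the real values are elided ("...") in the output above.
new_label = np.array([1., 1., 2., 1., 1., 1.])

# Count how many test samples received each predicted label.
labels, counts = np.unique(new_label, return_counts=True)
for label, count in zip(labels, counts):
    print("label %g: %d samples" % (label, count))
```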

But I get this error:

Traceback (most recent call last):
  File "DT.py", line 31, in <module>
    main()
  File "DT.py", line 24, in main
    tree.export_graphviz(clf, out_file=out)
  File "/usr/local/lib/python2.7/dist-packages/scikit_learn-0.14.1-py2.7-linux-x86_64.egg/sklearn/tree/export.py", line 131, in export_graphviz
    recurse(decision_tree.tree_, 0)
AttributeError: 'numpy.ndarray' object has no attribute 'tree_'
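The traceback is consistent with the line clf = clf.predict(testdata): fit() returns the fitted estimator, but predict() returns a plain NumPy array of labels, so after that line the name clf no longer refers to a tree and export_graphviz cannot find the tree_ attribute. A small sketch with made-up data (assumed, not the contents of m.mat) shows the rebinding:

```python
import numpy as np
from sklearn import tree

# Small synthetic data set (hypothetical, stands in for m.mat).
x = np.array([[1, 6, 1], [1, 160, 31], [1842, 26083, 0.78], [1842, 26181, 0.52]])
y = np.array([1., 1., 2., 2.])

clf = tree.DecisionTreeClassifier(random_state=99, criterion="entropy")
clf = clf.fit(x, y)               # still a DecisionTreeClassifier
print(type(clf).__name__, hasattr(clf, "tree_"))

clf = clf.predict(x)              # the name now points at a label array
print(type(clf).__name__, hasattr(clf, "tree_"))
```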

Does anyone know how to fix this? Thanks!
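One likely fix, sketched under assumptions: keep the predictions in a separate variable (new_label is a hypothetical name) and pass the fitted classifier, not the predictions, to export_graphviz. The sketch uses synthetic data instead of m.mat and testing.txt, drops min_samples_split=20 because the toy set has only four samples, and uses io.StringIO because sklearn.externals.six no longer exists in recent scikit-learn releases; it assumes Python 3 and a recent scikit-learn.

```python
import io

import numpy as np
from sklearn import tree

# Hypothetical synthetic data in place of m.mat and testing.txt.
x = np.array([[1, 6, 1], [1, 160, 31], [1842, 26083, 0.78], [1842, 26181, 0.52]])
y = np.array([1., 1., 2., 2.])
testdata = np.array([[1, 151, 3], [952, 26083, 0.96]])

clf = tree.DecisionTreeClassifier(random_state=99, criterion="entropy")
clf = clf.fit(x, y)

# Keep predictions in their own variable so clf stays a fitted tree.
new_label = clf.predict(testdata)
print("NewLabel:", new_label)

# export_graphviz needs the fitted estimator, not the prediction array.
out = io.StringIO()
tree.export_graphviz(clf, out_file=out)
dot_data = out.getvalue()
print(dot_data.splitlines()[0])
```

The DOT text in dot_data can then be handed to pydot as in the question to produce output.pdf; that step also needs Graphviz installed, and recent pydot releases return a list of graphs from graph_from_dot_data, so the first element would have to be taken before calling write_pdf. Note that export_graphviz draws the learned tree, not the test predictions; the predicted labels themselves are in new_label.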

0 answers