如何在csv文件中读取scikit-learn树

时间:2013-12-11 21:38:07

标签: python numpy scikit-learn

我对Rpython直接熟悉,但不熟悉scikit-learnnumpy

我在http://scikit-learn.org/stable/modules/tree.html#treescikit-learn

找到了from sklearn.datasets import load_iris from sklearn import tree iris = load_iris() clf = tree.DecisionTreeClassifier() clf = clf.fit(iris.data, iris.target) 中包含的虹膜数据集的以下代码
scikit-learn

我不想使用属于A,B,C,D 5.1,3.5,1.4,0.2 4.9,3.0,1.4,0.2 4.7,3.2,1.3,0.2 4.6,3.1,1.5,0.2 ......... 的iris,而是想加载以下格式的csv文件:

clf.fit(?,?)

如何加载它,将其加入iris.target以及我需要使用哪些代替{{1}}?

2 个答案:

答案 0 :(得分:0)

我建议你使用pandas。它实现了类似于R数据帧的东西。在将数据帧与sklearn一起使用之前,您需要将数据帧转换为numpy数组(np.array(df))

答案 1 :(得分:0)

data.txt中:

a,b,c,d
5.1,3.5,1.4,0.2
4.9,3.0,1.4,0.2
4.7,3.2,1.3,0.2
4.6,3.1,1.5,0.2

要加载数据,您可以使用numpy.loadtxt:

    import numpy as np
    from sklearn import tree

    mydata=np.loadtxt('data.txt',dtype=np.object,delimiter=',')
    mydata=mydata[1:].astype(np.float) # Perform conversion (for quantitative features only)

    clf=tree.DecisionTreeClassifier()

    #According to sklearn documentation we should map all class marks to integers
    #Lets do it:
    translation_table={'mark1':1,'mark2':2,'mark3':3} #or {'setosa': 1, 'virginia' :2} etc.
    target_data=['mark1','mark2','mark1','mark3', ] #etc.
    int_target_data=map(lambda x: translation_table[x],target_data) # Perform mapping needed by sklearn classifiers
    clf.fit(mydata,int_target_data) # train your classifier