如何在几个属性上学习(树)?

时间:2016-08-22 12:03:11

标签: python-3.x error-handling machine-learning tree scikit-learn

我在一个小小的虚拟机器学习问题中使用python和scikit-learn的树分类器。我有二进制结果变量(wc_measure),我相信它依赖于一些其他变量(cashcrisisindustry)。我尝试了以下方法:

#   import neccessary packages
import pandas as pd
import numpy as np
import sklearn as skl
from sklearn import tree
from sklearn.cross_validation import train_test_split as tts


#   import data and give a little overview
sample = pd.read_stata('sample_data.dta')

s = sample


#   What I want to learn on
X = [s.crisis, s.cash, s.industry]
y = s.wc_measure
X_train, X_test, y_train, y_test = tts(X, y, test_size = .5)


#let's learn a little

my_tree = tree.DecisionTreeClassifier()
clf = my_tree.fit(X_train, y_train)
predictions = my_tree.predict(X_test)

我收到以下错误:Number of labels=50 does not match number of samples=1。如果我将X基于单个变量(例如X = s.crisis),我会被要求重新塑造X。我不完全理解为什么我会遇到这些问题...想法?

PS:这是print(X)的返回

[0     4.0
1     4.0
2     5.0
3     3.0
4     4.0
5     2.0
6     2.0
7     1.0
8     3.0
9     3.0
10    4.0
11    3.0
12    2.0
13    4.0
14    5.0
15    4.0
16    2.0
17    2.0
18    3.0
19    2.0
20    5.0
21    4.0
22    2.0
23    4.0
24    5.0
25    1.0
26    5.0
27    3.0
28    4.0
29    2.0
     ... 
70    1.0
71    4.0
72    4.0
73    1.0
74    4.0
75    3.0
76    4.0
77    2.0
78    2.0
79    5.0
80    2.0
81    3.0
82    5.0
83    4.0
84    4.0
85    5.0
86    3.0
87    3.0
88    4.0
89    2.0
90    2.0
91    3.0
92    3.0
93    4.0
94    3.0
95    1.0
96    4.0
97    2.0
98    3.0
99    4.0
Name: crisis, dtype: float32, 0      450.283417
1      113.472214
2       11.811784
3     1007.507446
4      293.895142
5     1133.297729
6     2237.830322
7     1475.787109
8      283.363678
9      626.888794
10      38.865730
11     991.999390
12    1115.746948
13     373.537231
14      97.570717
15     136.079193
16    2560.691406
17     667.062073
18    1378.384521
19     152.716400
20       5.779267
21     481.511566
22     677.809631
23     722.521790
24      32.927990
25    2504.450928
26      17.422865
27     651.585083
28     549.469177
29     297.458527
         ...     
70    1198.370239
71     471.343933
72     389.709290
73    2962.622803
74     581.519287
75    1148.822388
76      67.653664
77    1346.391602
78    1764.086914
79      14.308219
80     973.152161
81     552.576904
82       2.863116
83     425.520752
84     321.773682
85      63.597332
86    1351.122559
87     735.856567
88     745.656677
89    2784.453125
90    1438.272705
91     768.780823
92     827.021423
93     591.778015
94     885.169434
95    1143.088867
96     399.816803
97    1517.454834
98    1311.692505
99     533.062561
Name: cash, dtype: float32, 0     5.0
1     2.0
2     3.0
3     5.0
4     4.0
5     3.0
6     5.0
7     1.0
8     1.0
9     2.0
10    1.0
11    5.0
12    2.0
13    4.0
14    6.0
15    2.0
16    6.0
17    2.0
18    5.0
19    1.0
20    3.0
21    4.0
22    2.0
23    6.0
24    4.0
25    4.0
26    3.0
27    3.0
28    5.0
29    1.0
     ... 
70    2.0
71    4.0
72    3.0
73    6.0
74    6.0
75    5.0
76    1.0
77    3.0
78    5.0
79    4.0
80    2.0
81    3.0
82    2.0
83    5.0
84    3.0
85    5.0
86    5.0
87    4.0
88    6.0
89    6.0
90    4.0
91    3.0
92    4.0
93    6.0
94    3.0
95    2.0
96    3.0
97    4.0
98    6.0
99    4.0

PPS:以下是我在Stata中生成数据的方法:

clear matrix
clear all
set more off

set obs 100
gen id = _n


*Basics
    gen industry = round(runiform()*5+1)
    gen activity = round(runiform()*5+1)
    gen crisis = round(runiform()*4+1)
        egen min_crisis = min(crisis)
        egen max_crisis = max(crisis)
        gen n_crisis = (crisis - min_crisis)/(max_crisis-min_crisis)

*Company details
    gen staff = round((0.5 * industry + 0.3 * activity - 0.2 * crisis) * runiform()*100+1) 

    gen revenue = (0.5 * industry + 0.2 * activity - 0.3 * crisis ) * 1000 + runiform()
        replace revenue = 0 if revenue<0

    *Working Capital (wc)
    gen stock = runiform()*0.5*crisis*revenue
    gen receivables = runiform()*0.5*crisis*revenue
    gen payables = runiform()*-0.5*crisis*revenue
        replace payables = 0 if payables < 0
    gen wc = stock + receivables - payables 
        egen avg_wc = mean(wc), by(industry)


    *Liquidity
    gen loan = (0.5 * industry + 0.2 * activity - 0.3 * crisis ) * 1000 + runiform()
        replace loan = 0 if loan<0
        egen pc_loan = pctile(loan), p(0.2) by(industry)
        replace loan = 0 if loan<pc_loan

    gen current_debt = n_crisis * loan + runiform()*100

    gen cash = (1-n_crisis)*revenue + runiform()*100


*Measures

    *WC-measure (binary)
        gen wc_status = (wc-avg_wc)
            egen max_wc_status = max(wc_status), by(industry)
            egen min_wc_status = min(wc_status), by(industry)
            gen n_wc_status = (wc_status - min_wc_status) / (max_wc_status-min_wc_status)
    gen wc_measure = round(n_wc_status)

2 个答案:

答案 0 :(得分:1)

你需要检查X是否是tts的正确输入? X有三行N列。 X应该有N行,有3个属性。这就是为什么它抱怨数字不匹配的原因。

答案 1 :(得分:0)

我终于解决了这个问题。问题是我没有将我的样本s定义为数组 - 相应地,X是一个列表。谢谢大家的帮忙!

这是我做的:

#   import data and give a little overview
sample = pd.read_stata('sample_data.dta')
s = sample
print(s.shape)


#   Have some mor vars and an array of explanatory vars


X = np.array((s.crisis, s.cash, s.industry)).reshape(100, 3)
y = np.array(s.wc_measure)
X_train, X_test, y_train, y_test = tts(X, y, test_size = .8)


#let's learn a little

my_tree = tree.DecisionTreeClassifier()
clf = my_tree.fit(X_train, y_train)
predictions = my_tree.predict(X_test)