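I'm using Python and scikit-learn's tree classifier on a small dummy machine-learning problem. I have a binary outcome variable (wc_measure) that I believe depends on a few other variables (cash, crisis and industry). I tried the following: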
# import necessary packages
import pandas as pd
import numpy as np
import sklearn as skl
from sklearn import tree
from sklearn.cross_validation import train_test_split as tts
# import data and give a little overview
sample = pd.read_stata('sample_data.dta')
s = sample
# What I want to learn on
X = [s.crisis, s.cash, s.industry]
y = s.wc_measure
X_train, X_test, y_train, y_test = tts(X, y, test_size = .5)
#let's learn a little
my_tree = tree.DecisionTreeClassifier()
clf = my_tree.fit(X_train, y_train)
predictions = my_tree.predict(X_test)
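I get the following error: Number of labels=50 does not match number of samples=1. If I base X on a single variable instead (e.g. X = s.crisis), I'm asked to reshape X. I don't fully understand why I'm running into these problems... Any ideas?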
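A minimal sketch of the single-variable case, using a tiny hypothetical DataFrame in place of sample_data.dta (the crisis, cash and industry values are the first rows of the printout below, rounded; the wc_measure values are made up). scikit-learn expects X to be two-dimensional with shape (n_samples, n_features), so one column has to be passed as an (n, 1) table rather than a 1-D Series.

```python
import pandas as pd
from sklearn.tree import DecisionTreeClassifier

# Hypothetical stand-in for the Stata sample; wc_measure values are invented
s = pd.DataFrame({
    "crisis": [4.0, 4.0, 5.0, 3.0],
    "cash": [450.3, 113.5, 11.8, 1007.5],
    "industry": [5.0, 2.0, 3.0, 5.0],
    "wc_measure": [1.0, 0.0, 0.0, 1.0],
})

one_d = s.crisis                              # Series, shape (4,): 1-D, so fit() asks for a reshape
two_d = s[["crisis"]]                         # double brackets select a DataFrame, shape (4, 1)
column = s.crisis.to_numpy().reshape(-1, 1)   # the reshape the error message suggests, shape (4, 1)

clf = DecisionTreeClassifier().fit(two_d, s.wc_measure)
print(one_d.shape, two_d.shape, column.shape)
```

Either 2-D form works; the double-bracket selection keeps the column name, which scikit-learn can use to check features at predict time.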
PS: this is what print(X) returns:
[0 4.0
1 4.0
2 5.0
3 3.0
4 4.0
5 2.0
6 2.0
7 1.0
8 3.0
9 3.0
10 4.0
11 3.0
12 2.0
13 4.0
14 5.0
15 4.0
16 2.0
17 2.0
18 3.0
19 2.0
20 5.0
21 4.0
22 2.0
23 4.0
24 5.0
25 1.0
26 5.0
27 3.0
28 4.0
29 2.0
...
70 1.0
71 4.0
72 4.0
73 1.0
74 4.0
75 3.0
76 4.0
77 2.0
78 2.0
79 5.0
80 2.0
81 3.0
82 5.0
83 4.0
84 4.0
85 5.0
86 3.0
87 3.0
88 4.0
89 2.0
90 2.0
91 3.0
92 3.0
93 4.0
94 3.0
95 1.0
96 4.0
97 2.0
98 3.0
99 4.0
Name: crisis, dtype: float32, 0 450.283417
1 113.472214
2 11.811784
3 1007.507446
4 293.895142
5 1133.297729
6 2237.830322
7 1475.787109
8 283.363678
9 626.888794
10 38.865730
11 991.999390
12 1115.746948
13 373.537231
14 97.570717
15 136.079193
16 2560.691406
17 667.062073
18 1378.384521
19 152.716400
20 5.779267
21 481.511566
22 677.809631
23 722.521790
24 32.927990
25 2504.450928
26 17.422865
27 651.585083
28 549.469177
29 297.458527
...
70 1198.370239
71 471.343933
72 389.709290
73 2962.622803
74 581.519287
75 1148.822388
76 67.653664
77 1346.391602
78 1764.086914
79 14.308219
80 973.152161
81 552.576904
82 2.863116
83 425.520752
84 321.773682
85 63.597332
86 1351.122559
87 735.856567
88 745.656677
89 2784.453125
90 1438.272705
91 768.780823
92 827.021423
93 591.778015
94 885.169434
95 1143.088867
96 399.816803
97 1517.454834
98 1311.692505
99 533.062561
Name: cash, dtype: float32, 0 5.0
1 2.0
2 3.0
3 5.0
4 4.0
5 3.0
6 5.0
7 1.0
8 1.0
9 2.0
10 1.0
11 5.0
12 2.0
13 4.0
14 6.0
15 2.0
16 6.0
17 2.0
18 5.0
19 1.0
20 3.0
21 4.0
22 2.0
23 6.0
24 4.0
25 4.0
26 3.0
27 3.0
28 5.0
29 1.0
...
70 2.0
71 4.0
72 3.0
73 6.0
74 6.0
75 5.0
76 1.0
77 3.0
78 5.0
79 4.0
80 2.0
81 3.0
82 2.0
83 5.0
84 3.0
85 5.0
86 5.0
87 4.0
88 6.0
89 6.0
90 4.0
91 3.0
92 4.0
93 6.0
94 3.0
95 2.0
96 3.0
97 4.0
98 6.0
99 4.0
PPS: this is how I generated the data in Stata:
clear matrix
clear all
set more off
set obs 100
gen id = _n
*Basics
gen industry = round(runiform()*5+1)
gen activity = round(runiform()*5+1)
gen crisis = round(runiform()*4+1)
egen min_crisis = min(crisis)
egen max_crisis = max(crisis)
gen n_crisis = (crisis - min_crisis)/(max_crisis-min_crisis)
*Company details
gen staff = round((0.5 * industry + 0.3 * activity - 0.2 * crisis) * runiform()*100+1)
gen revenue = (0.5 * industry + 0.2 * activity - 0.3 * crisis ) * 1000 + runiform()
replace revenue = 0 if revenue<0
*Working Capital (wc)
gen stock = runiform()*0.5*crisis*revenue
gen receivables = runiform()*0.5*crisis*revenue
gen payables = runiform()*-0.5*crisis*revenue
replace payables = 0 if payables < 0
gen wc = stock + receivables - payables
egen avg_wc = mean(wc), by(industry)
*Liquidity
gen loan = (0.5 * industry + 0.2 * activity - 0.3 * crisis ) * 1000 + runiform()
replace loan = 0 if loan<0
egen pc_loan = pctile(loan), p(0.2) by(industry)
replace loan = 0 if loan<pc_loan
gen current_debt = n_crisis * loan + runiform()*100
gen cash = (1-n_crisis)*revenue + runiform()*100
*Measures
*WC-measure (binary)
gen wc_status = (wc-avg_wc)
egen max_wc_status = max(wc_status), by(industry)
egen min_wc_status = min(wc_status), by(industry)
gen n_wc_status = (wc_status - min_wc_status) / (max_wc_status-min_wc_status)
gen wc_measure = round(n_wc_status)
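If you would rather generate the data directly in Python, here is an approximate NumPy/pandas port of the Stata script above, restricted to the variables the classifier uses (crisis, cash, industry, wc_measure) and their inputs. simulate_sample is a hypothetical helper, not part of the original post; NumPy's random stream differs from Stata's runiform(), so the values will not match the .dta file, and np.round sends exact .5 ties to the even integer (which only matters on exact ties).

```python
import numpy as np
import pandas as pd


def simulate_sample(n=100, seed=0):
    """Approximate port of the Stata script (hypothetical helper, not the original code)."""
    rng = np.random.default_rng(seed)

    def runiform():
        # Stand-in for Stata's runiform(): one uniform draw on [0, 1) per observation
        return rng.uniform(size=n)

    df = pd.DataFrame({"id": np.arange(1, n + 1)})
    # Basics
    df["industry"] = np.round(runiform() * 5 + 1)
    df["activity"] = np.round(runiform() * 5 + 1)
    df["crisis"] = np.round(runiform() * 4 + 1)
    c_min, c_max = df["crisis"].min(), df["crisis"].max()
    df["n_crisis"] = (df["crisis"] - c_min) / (c_max - c_min)
    # Company details (staff is skipped: nothing below uses it)
    revenue = (0.5 * df["industry"] + 0.2 * df["activity"] - 0.3 * df["crisis"]) * 1000 + runiform()
    df["revenue"] = revenue.clip(lower=0)
    # Working capital
    stock = runiform() * 0.5 * df["crisis"] * df["revenue"]
    receivables = runiform() * 0.5 * df["crisis"] * df["revenue"]
    payables = (runiform() * -0.5 * df["crisis"] * df["revenue"]).clip(lower=0)
    df["wc"] = stock + receivables - payables
    df["avg_wc"] = df.groupby("industry")["wc"].transform("mean")
    # Liquidity (loan and current_debt are skipped: cash does not depend on them)
    df["cash"] = (1 - df["n_crisis"]) * df["revenue"] + runiform() * 100
    # Binary WC measure: min-max normalise wc - avg_wc within each industry, then round
    df["wc_status"] = df["wc"] - df["avg_wc"]
    by_ind = df.groupby("industry")["wc_status"]
    lo, hi = by_ind.transform("min"), by_ind.transform("max")
    df["n_wc_status"] = (df["wc_status"] - lo) / (hi - lo)
    df["wc_measure"] = np.round(df["n_wc_status"])
    return df
```

Calling s = simulate_sample() gives a DataFrame with the same column names as the .dta file, so s.crisis, s.cash, s.industry and s.wc_measure can be used in the code above unchanged.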
Answer 0 (score: 1)
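You need to check whether X is the right input for tts. Your X has 3 rows and N columns, but it should have N rows with 3 attributes each. That's why it complains that the numbers don't match.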
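To make the shape mismatch concrete, a small sketch on hypothetical toy data with easy-to-track values: stacking the three Series gives one row per variable, i.e. shape (3, n). Transposing, or selecting the columns straight from the DataFrame, yields the (n, 3) layout scikit-learn expects, while reshape(n, 3) only matches the shape and reorders the values.

```python
import numpy as np
import pandas as pd

# Hypothetical toy data: 5 observations, 3 variables
n = 5
s = pd.DataFrame({
    "crisis": np.arange(n) + 1.0,      # 1, 2, 3, 4, 5
    "cash": np.arange(n) * 100.0,      # 0, 100, 200, 300, 400
    "industry": np.arange(n) + 10.0,   # 10, 11, 12, 13, 14
})

as_list = [s.crisis, s.cash, s.industry]   # what the question passes as X
stacked = np.array(as_list)                # one ROW per variable -> shape (3, n)

X_good = stacked.T                         # transpose: one row per observation, shape (n, 3)
X_bad = stacked.reshape(n, 3)              # right shape, but values re-read in row-major order
X_cols = s[["crisis", "cash", "industry"]].to_numpy()  # same as X_good, selected directly

print(stacked.shape, X_good.shape)
print(X_good[1])   # observation 1: crisis 2, cash 100, industry 11
print(X_bad[1])    # crisis values 4 and 5 plus a cash value: not a real observation
```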
Answer 1 (score: 0)
I finally solved the problem. The issue was that I hadn't defined my sample s as an array, so X was a list. Thanks everyone for the help!
Here's what I did:
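# import necessary packages
import numpy as np
import pandas as pd
from sklearn import tree
from sklearn.model_selection import train_test_split as tts  # sklearn.cross_validation was removed in 0.20
# import data and give a little overview
sample = pd.read_stata('sample_data.dta')
s = sample
print(s.shape)
# Build an (n_samples, 3) array of explanatory vars, one row per observation.
# np.array((s.crisis, s.cash, s.industry)) has shape (3, n_samples); .T transposes it,
# whereas .reshape(100, 3) would give the right shape but scramble the values.
X = np.array((s.crisis, s.cash, s.industry)).T
y = np.array(s.wc_measure)
X_train, X_test, y_train, y_test = tts(X, y, test_size = .8)
#let's learn a little
my_tree = tree.DecisionTreeClassifier()
clf = my_tree.fit(X_train, y_train)
predictions = my_tree.predict(X_test)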
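For readers without sample_data.dta, a self-contained sketch of the same workflow on hypothetical random data (the target rule below is made up for illustration and is not the Stata formula). Selecting the columns with double brackets keeps X as an (n_samples, 3) DataFrame, so no manual array conversion is needed. It assumes scikit-learn 0.20 or newer, where train_test_split lives in sklearn.model_selection.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Hypothetical random data with the question's column names;
# wc_measure is an arbitrary rule here, NOT the Stata formula.
rng = np.random.default_rng(42)
n = 100
s = pd.DataFrame({
    "crisis": rng.integers(1, 6, size=n).astype(float),
    "cash": rng.uniform(0, 3000, size=n),
    "industry": rng.integers(1, 7, size=n).astype(float),
})
s["wc_measure"] = (s.cash > 1000).astype(int)

X = s[["crisis", "cash", "industry"]]   # DataFrame, shape (100, 3): one row per observation
y = s["wc_measure"]                     # Series, shape (100,)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=0)
clf = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)
predictions = clf.predict(X_test)
print(X_train.shape, X_test.shape, clf.score(X_test, y_test))
```

Fixing random_state in both the split and the tree makes repeated runs give the same split and the same fitted tree, which helps when comparing feature sets.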