I have a dataframe with four columns, 'a', 'b', 'c', and 'target', where target is the sum of the first three values. I train an H2ONaiveBayesEstimator on this dataframe and then try to predict on a similar dataframe, but the results are unexpected.
I tried the same dataframes with LinearRegression and it gives the expected predictions.
Why does H2ONaiveBayesEstimator predict the value 27 for all 20 rows?
Code:
import pandas as pd
import h2o
from h2o.estimators import H2ONaiveBayesEstimator
from sklearn.linear_model import LinearRegression

# Training data: target is the sum of columns a, b and c.
train_df = pd.DataFrame(
    {
        "a": list(range(10)),
        "b": list(range(10)),
        "c": list(range(10)),
        "target": [i + i + i for i in range(10)],
    }
)

# Data to predict on: the same three feature columns, values 10..29.
pred_df = pd.DataFrame(
    {
        "a": list(range(10, 30)),
        "b": list(range(10, 30)),
        "c": list(range(10, 30)),
    }
)

train_data = train_df.values
features_data = train_df.drop('target', axis=1).values
target_data = train_df['target'].values
pred_data = pred_df.values

# scikit-learn LinearRegression on the same data, for comparison.
lr = LinearRegression()
lr.fit(features_data, target_data)
predictions = lr.predict(pred_data)
print(predictions)
print('\n----\n')

# H2O Naive Bayes: the target column is converted to a factor before training.
h2o.init()
train_frame = h2o.H2OFrame(train_data, column_names=list(train_df.columns))
nbm = H2ONaiveBayesEstimator()
train_frame['target'] = train_frame['target'].asfactor()
nbm.train(y='target', training_frame=train_frame)
pred_frame = h2o.H2OFrame(pred_data, column_names=list(pred_df.columns))
predictions = nbm.predict(pred_frame)
print('\n----\n')
print(predictions.as_data_frame())
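Not part of the original script, but for context: because of the asfactor() call, the Naive Bayes model treats target as a categorical response, so one quick way to see which values it can predict at all is to list the factor levels (a minimal sketch, assuming the train_frame built above):

# Diagnostic sketch (not in the original run): list the levels of the
# categorical target column; the classifier only chooses among these
# training values (0, 3, ..., 27), which appear as the p0..p27 columns
# in the prediction output further down.
print(train_frame['target'].levels())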
Train dataframe:
a b c target
0 0 0 0 0
1 1 1 1 3
2 2 2 2 6
3 3 3 3 9
4 4 4 4 12
5 5 5 5 15
6 6 6 6 18
7 7 7 7 21
8 8 8 8 24
9 9 9 9 27
Prediction dataframe:
a b c
0 10 10 10
1 11 11 11
2 12 12 12
3 13 13 13
4 14 14 14
5 15 15 15
6 16 16 16
7 17 17 17
8 18 18 18
9 19 19 19
10 20 20 20
11 21 21 21
12 22 22 22
13 23 23 23
14 24 24 24
15 25 25 25
16 26 26 26
17 27 27 27
18 28 28 28
19 29 29 29
H2ONaiveBayesEstimator results:
predict p0 p3 ... p21 p24 p27
0 27 3.180305e-65 7.583358e-53 ... 6.076669e-06 1.098688e-02 0.989007
1 27 6.040575e-77 2.893040e-63 ... 1.522156e-08 5.527786e-04 0.999447
2 27 1.135940e-88 1.092736e-73 ... 3.775031e-11 2.753569e-05 0.999972
3 27 2.135088e-100 4.125332e-84 ... 9.357610e-14 1.370957e-06 0.999999
4 27 4.012965e-112 1.557370e-94 ... 2.319523e-16 6.825603e-08 1.000000
5 27 7.542484e-124 5.879283e-105 ... 5.749522e-19 3.398268e-09 1.000000
6 27 1.417632e-135 2.219508e-115 ... 1.425164e-21 1.691898e-10 1.000000
7 27 2.664479e-147 8.378943e-126 ... 3.532629e-24 8.423464e-12 1.000000
8 27 5.007966e-159 3.163164e-136 ... 8.756511e-27 4.193796e-13 1.000000
9 27 9.412616e-171 1.194137e-146 ... 2.170522e-29 2.087968e-14 1.000000
10 27 1.769128e-182 4.508027e-157 ... 5.380186e-32 1.039538e-15 1.000000
11 27 3.325128e-194 1.701841e-167 ... 1.333615e-34 5.175555e-17 1.000000
12 27 6.249673e-206 6.424678e-178 ... 3.305701e-37 2.576757e-18 1.000000
13 27 1.174644e-217 2.425402e-188 ... 8.194013e-40 1.282892e-19 1.000000
14 27 2.207777e-229 9.156221e-199 ... 2.031093e-42 6.387142e-21 1.000000
15 27 4.149581e-241 3.456597e-209 ... 5.034575e-45 3.179971e-22 1.000000
16 27 7.799257e-253 1.304912e-219 ... 1.247946e-47 1.583214e-23 1.000000
17 27 1.465893e-264 4.926217e-230 ... 3.093350e-50 7.882360e-25 1.000000
18 27 2.755188e-276 1.859713e-240 ... 7.667648e-53 3.924396e-26 1.000000
19 27 5.178455e-288 7.020668e-251 ... 1.900620e-55 1.953842e-27 1.000000
[20 rows x 11 columns]
LinearRegression results:
[30. 33. 36. 39. 42. 45. 48. 51. 54. 57. 60. 63. 66. 69. 72. 75. 78. 81.
84. 87.]
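As a sanity check (not in the original script; assumes numpy is available alongside the imports above), the LinearRegression output can be compared directly against the row sums of pred_df, which is exactly how target was constructed:

import numpy as np
# The predicted values 30, 33, ..., 87 equal a + b + c for each prediction
# row, i.e. the relationship the training target was built from.
assert np.allclose(lr.predict(pred_data), pred_df.sum(axis=1).values)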