H2ONaiveBayesEstimator预测意外结果

时间:2019-09-19 17:03:50

标签: python h2o

我有一个包含四列的数据框:'a','b','c','target',其中target是前三个值的总和。 我使用此数据帧训练H2ONaiveBayesEstimator,并尝试预测相似数据帧的结果。但是结果出乎意料。

我用LinearRegression尝试了相同的数据帧,它给出了预期的预测。

为什么H2ONaiveBayesEstimator给出这样的预测结果,即20行的值为27?

代码:

import pandas as pd
import h2o
from h2o.estimators import H2ONaiveBayesEstimator
from sklearn.linear_model import LinearRegression

train_df = pd.DataFrame(
    {
        "a": list(range(10)),
        "b": list(range(10)),
        "c": list(range(10)),
        "target": [i + i + i for i in list(range(10))]
    }
)

pred_df = pd.DataFrame(
    {
        "a": list(range(10, 30)),
        "b": list(range(10, 30)),
        "c": list(range(10, 30)),
    }
)

train_data = train_df.values
features_data = train_df.drop('target', axis=1).values
target_data = train_df['target'].values

pred_data = pred_df.values

lr = LinearRegression()
lr.fit(features_data, target_data)
predictions = lr.predict(pred_data)
print(predictions)

print('\n----\n')

h2o.init()
train_frame = h2o.H2OFrame(train_data, column_names=list(train_df.columns))
nbm = H2ONaiveBayesEstimator()
train_frame['target'] = train_frame['target'].asfactor()
nbm.train(y='target', training_frame=train_frame)

pred_frame = h2o.H2OFrame(pred_data, column_names=list(pred_df.columns))
predictions = nbm.predict(pred_frame)
print('\n----\n')
print(predictions.as_data_frame())

火车数据帧:

   a  b  c  target
0  0  0  0       0
1  1  1  1       3
2  2  2  2       6
3  3  3  3       9
4  4  4  4      12
5  5  5  5      15
6  6  6  6      18
7  7  7  7      21
8  8  8  8      24
9  9  9  9      27

pred数据帧:

0   10  10  10
1   11  11  11
2   12  12  12
3   13  13  13
4   14  14  14
5   15  15  15
6   16  16  16
7   17  17  17
8   18  18  18
9   19  19  19
10  20  20  20
11  21  21  21
12  22  22  22
13  23  23  23
14  24  24  24
15  25  25  25
16  26  26  26
17  27  27  27
18  28  28  28
19  29  29  29

H2ONaiveBayesEstimator的结果:

    predict             p0             p3  ...           p21           p24       p27
0        27   3.180305e-65   7.583358e-53  ...  6.076669e-06  1.098688e-02  0.989007
1        27   6.040575e-77   2.893040e-63  ...  1.522156e-08  5.527786e-04  0.999447
2        27   1.135940e-88   1.092736e-73  ...  3.775031e-11  2.753569e-05  0.999972
3        27  2.135088e-100   4.125332e-84  ...  9.357610e-14  1.370957e-06  0.999999
4        27  4.012965e-112   1.557370e-94  ...  2.319523e-16  6.825603e-08  1.000000
5        27  7.542484e-124  5.879283e-105  ...  5.749522e-19  3.398268e-09  1.000000
6        27  1.417632e-135  2.219508e-115  ...  1.425164e-21  1.691898e-10  1.000000
7        27  2.664479e-147  8.378943e-126  ...  3.532629e-24  8.423464e-12  1.000000
8        27  5.007966e-159  3.163164e-136  ...  8.756511e-27  4.193796e-13  1.000000
9        27  9.412616e-171  1.194137e-146  ...  2.170522e-29  2.087968e-14  1.000000
10       27  1.769128e-182  4.508027e-157  ...  5.380186e-32  1.039538e-15  1.000000
11       27  3.325128e-194  1.701841e-167  ...  1.333615e-34  5.175555e-17  1.000000
12       27  6.249673e-206  6.424678e-178  ...  3.305701e-37  2.576757e-18  1.000000
13       27  1.174644e-217  2.425402e-188  ...  8.194013e-40  1.282892e-19  1.000000
14       27  2.207777e-229  9.156221e-199  ...  2.031093e-42  6.387142e-21  1.000000
15       27  4.149581e-241  3.456597e-209  ...  5.034575e-45  3.179971e-22  1.000000
16       27  7.799257e-253  1.304912e-219  ...  1.247946e-47  1.583214e-23  1.000000
17       27  1.465893e-264  4.926217e-230  ...  3.093350e-50  7.882360e-25  1.000000
18       27  2.755188e-276  1.859713e-240  ...  7.667648e-53  3.924396e-26  1.000000
19       27  5.178455e-288  7.020668e-251  ...  1.900620e-55  1.953842e-27  1.000000

[20 rows x 11 columns]

LinearRegression的结果

[30. 33. 36. 39. 42. 45. 48. 51. 54. 57. 60. 63. 66. 69. 72. 75. 78. 81.
 84. 87.]

0 个答案:

没有答案