如何在H2O GBM和DRF中更改预测

时间:2018-07-10 10:23:48

标签: python-3.x random-forest h2o gbm

我正在h2o DRF和GBM中建立分类模型。我想更改预测的概率,以便如果p0 <0.2则predict = 0,否则predict = 1

1 个答案:

答案 0 :(得分:1)

当前,您需要手动执行此操作。如果我们为threshold方法使用predict()自变量会更容易些,所以我创建了JIRA ticket票证以使其更加简单。

请参见下面的Python示例,了解如何手动执行此操作。

import h2o
from h2o.estimators.gbm import H2OGradientBoostingEstimator
h2o.init()

# Import a sample binary outcome train/test set into H2O
train = h2o.import_file("https://s3.amazonaws.com/erin-data/higgs/higgs_train_10k.csv")
test = h2o.import_file("https://s3.amazonaws.com/erin-data/higgs/higgs_test_5k.csv")

# Identify predictors and response
x = train.columns
y = "response"
x.remove(y)

# For binary classification, response should be a factor
train[y] = train[y].asfactor()
test[y] = test[y].asfactor()

# Train and cross-validate a GBM
my_gbm = H2OGradientBoostingEstimator(distribution="bernoulli", seed=1)
my_gbm.train(x=x, y=y, training_frame=train)

# Predict on a test set using default threshold
pred = my_gbm.predict(test_data=test)

看一下pred框架:

In [16]: pred.tail()
Out[16]:
  predict        p0        p1
---------  --------  --------
        1  0.484712  0.515288
        0  0.693893  0.306107
        1  0.319674  0.680326
        0  0.582344  0.417656
        1  0.471658  0.528342
        1  0.079922  0.920078
        1  0.150146  0.849854
        0  0.835288  0.164712
        0  0.639877  0.360123
        1  0.54377   0.45623

[10 rows x 3 columns]

以下是手动创建所需预测的方法。 H2O User Guide中提供了有关如何切片H2OFrame的更多信息。

# Binary column which is 1 if >=0.2 and 0 if <0.2
newpred = pred["p1"] >= 0.2 

newpred.tail()

查看二进制列:

In [23]: newpred.tail()
Out[23]:
  p1
----
   1
   1
   1
   1
   1
   1
   1
   0
   1
   1

[10 rows x 1 column]

现在您有了想要的预测。您也可以将"predict"列替换为新的预测标签。

pred["predict"] = newpred

现在重新检查pred框架:

In [24]: pred.tail()
Out[24]:
  predict        p0        p1
---------  --------  --------
        1  0.484712  0.515288
        1  0.693893  0.306107
        1  0.319674  0.680326
        1  0.582344  0.417656
        1  0.471658  0.528342
        1  0.079922  0.920078
        1  0.150146  0.849854
        0  0.835288  0.164712
        1  0.639877  0.360123
        1  0.54377   0.45623

[10 rows x 3 columns]