我正在使用H2O deeplearning前馈深度神经网络进行二进制分类。我的课程非常不平衡,我想使用像
这样的参数balance_classes,class_sampling_factors
任何机构都可以给我一个可重复的例子,说明如何专门初始化这些参数来处理类不平衡问题。
答案 0 :(得分:5)
首先,这是完整的,可重现的例子:
library(h2o)
h2o.init()
data(iris) #Not required?
iris <- iris[1:120,] #Remove 60% of virginica
summary(iris$Species) #50/50/20
d <- as.h2o(iris)
splits = h2o.splitFrame(d,0.8,c("train","test"), seed=77)
train = splits[[1]]
test = splits[[2]]
summary(train$Species) #41/41/14
summary(test$Species) #9/9/6
m1 = h2o.randomForest(1:4, 5, train, model_id ="RF_defaults", seed=1)
h2o.confusionMatrix(m1)
m2 = h2o.randomForest(1:4, 5, train, model_id ="RF_balanced", seed=1,
balance_classes = TRUE)
h2o.confusionMatrix(m2)
m3 = h2o.randomForest(1:4, 5, train, model_id ="RF_balanced", seed=1,
balance_classes = TRUE,
class_sampling_factors = c(1, 1, 2.5)
)
h2o.confusionMatrix(m3)
第一行初始化H2O,然后我故意修改虹膜数据集,扔掉3个类别中的一个的60%,以造成不平衡。
接下来的几行将数据加载到H2O中,并创建80%/ 20%的列车/测试数据分割。故意选择种子,因此在训练数据 virginica 中有14.58%的数据,而原始数据为16.67%,测试数据为25%。
然后我训练了三个随机森林模型。m1
是默认值,其混淆矩阵如下所示:
setosa versicolor virginica Error Rate
setosa 41 0 0 0.0000 = 0 / 41
versicolor 0 39 2 0.0488 = 2 / 41
virginica 0 1 13 0.0714 = 1 / 14
Totals 41 40 15 0.0312 = 3 / 96
这里没什么可看的:它使用它找到的数据。
现在这里是m2
的相同输出,它会打开balance_classes
。你可以看到它被 virginica 类过度采样,以使它们尽可能平衡。 (最右边的列表示41,41,40而不是前一输出中的41,41,14。)
setosa versicolor virginica Error Rate
setosa 41 0 0 0.0000 = 0 / 41
versicolor 0 41 0 0.0000 = 0 / 41
virginica 0 2 38 0.0500 = 2 / 40
Totals 41 43 38 0.0164 = 2 / 122
在m3
我们仍然会打开balance_classes
,但也会告诉它情况的真相。即实际数据是16.67% virginica ,而不是它在train
数据中看到的14.58%。 m3
的混淆矩阵表明它将14个 virginica 样本变为37个样本而不是40个样本。
setosa versicolor virginica Error Rate
setosa 41 0 0 0.0000 = 0 / 41
versicolor 0 41 0 0.0000 = 0 / 41
virginica 0 2 35 0.0541 = 2 / 37
Totals 41 43 35 0.0168 = 2 / 119
我怎么知道写c(1, 1, 2.5)
,而不是c(2.5, 1, 1)
或c(1, 2.5, 1)
?文档说它必须是&#34;词典顺序&#34;。您可以找到该订单的用途:
h2o.levels(train$Species)
告诉我:
[1] "setosa" "versicolor" "virginica"
意见位:balance_classes
可以开启,但只有当您有充分的理由相信您的培训数据不具代表性时,才应使用class_sampling_factors
。
注意:根据我即将出版的“H2O实用机器学习”一书改编的代码和说明。