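gbm binary classification in R

Date: 2017-04-16 15:07:28

Tags: r machine-learning classification

I am trying to perform binary classification. The response variable is "class", which takes the value 0 or 1. However, the output of the predict function (yhat.boost1) is a continuous variable. How can I make it binary?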

set.seed(2016)

seismic1 <- read.csv("seismic.csv")
par(mfrow=c(1,2))

seismic1[,c(4:7,9:13,17:18)] <- seismic1[,c(4:7,9:13,17:18)]
seismic1 <- seismic1[,-(14:16)]

for(i in c(1:3,8)){
  seismic1[,i] <- as.numeric(seismic1[,i])
}

#set training and test data   
n <- dim(seismic1)[1]
p <- dim(seismic1)[2]

set.seed(2016)
test <- sample(n, round(n/4))
train <- (1:n)[-test]
seismic1.train <- seismic1[train,]
seismic1.test <- seismic1[test,]

# perform gbm classification
library(gbm)
start.time <- proc.time()
boost.seismic1 <- gbm(class ~ ., data = seismic1.train, distribution = "bernoulli", n.trees = 5000, interaction.depth = 4)
summary(boost.seismic1)


# predict on the test dataset; type = "response" returns probabilities, not class labels
yhat.boost1 <- predict(boost.seismic1, newdata = seismic1[-train, ], n.trees = 500, type = "response")
yhat.boost1

1 Answer:

Answer 0: (score: 0)

You just need to apply a cutoff of 0.5 to the predicted probabilities:

 ifelse(yhat.boost1>0.5,1,0)

That is, if the predicted probability is greater than 50%, the class is set to 1; otherwise (50% or below) it is set to 0.
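
As a minimal sketch of the cutoff step, here is the same idea in base R on made-up numbers rather than the asker's model output. The vectors `probs` and `actual`, their values, and the names `cutoff`, `pred_class`, `conf_mat` and `accuracy` are all hypothetical and only for illustration:

```r
# Hypothetical predicted probabilities and true labels (illustrative values only)
probs  <- c(0.12, 0.58, 0.50, 0.91, 0.33)
actual <- c(0, 1, 0, 1, 1)

# Turn probabilities into 0/1 labels; values exactly at the cutoff become 0
cutoff <- 0.5
pred_class <- as.integer(probs > cutoff)

# Cross-tabulate predictions against the true labels and compute accuracy
conf_mat <- table(predicted = pred_class, actual = actual)
accuracy <- mean(pred_class == actual)

print(conf_mat)
print(accuracy)
```

With the real model, `probs` would be `yhat.boost1` and `actual` would be the `class` column of the same test rows. The 0.5 cutoff is only a common default; a different value may work better if the two classes are imbalanced.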