R一类SVM-获取概率输出

时间:2020-05-11 16:00:18

标签: r machine-learning svm libsvm one-class-classification

当我从R中的一类svm进行预测时,我试图找出概率输出。我知道libsvm不支持此功能,我也知道已经问过这个问题beforehere是几年前在SO上使用的,但那时还没有软件包。我希望一切都变了!同样,这个问题仍然有效,因为没有提供R中实现的方法作为解决方案。

我找不到执行此操作的软件包,因此我尝试了两种方法来解决此问题:

  1. 获取决策值并通过使用S型激活函数对其进行转换。此paper中对此进行了描述。请注意以下段落:

此外,SVM还可以产生类概率作为输出而不是类标签。这个 可以通过改进普拉特后代的实现(Lin,Lin和Weng 2001)来完成 将S型函数拟合到二进制SVM分类器的决策值f的概率(Platt 2000),通过最小化对数似然函数来估计A和B

  1. 对预测的输出使用逻辑回归函数,并从中得出概率。 Platt首先描述了这种方法,概述了here

我的问题是,要检查我的两个解决方案中的任何一个是否合理,我针对两个类别的svm问题测试了这两种方法,因为e1071使用libsvm给出了两个类别的概率问题,因此被视为“真相”。我发现我的方法都没有与libsvm紧密结合。

这是三个图表,显示了结果概率与已知决策值的关系。 Click to see image. Sorry I seem to have too low a reputation to embed the image which is frustrating! I'm not sure if someone in the community with a higher reputation can edit to embed?

我认为我的普拉特方法在理论上更合理,但是,从图中可以看出,逻辑回归似乎在某种程度上太好了,与这两种分类相关的概率对于正值非常接近1,对于负值则非常接近0。

我为Platt实现的代码是

platt_scale <- function(oc_svm, X){
  # Get SVM predictions
  y_pred <- predict(oc_svm$best.model,X)
  #y_pred <- as.factor(ifelse(y_pred==T,"pos","neg"))
  # Train using logistic regression with cross-validation 
  require(caret)
  model <- train(x = X,
                 y = y_pred,
                 method = "glm",
                 family=binomial(),
                 trControl = trainControl(method = "cv",
                                          number = 5),
                 control = list(maxit = 50) #BROUGHT IN TO STOP WARNING MESSAGES
  )
  return(predict(model,
                 newdata = X,
                 type = "prob")[,1]) 
}

运行时我收到以下警告

glm.fit: fitted probabilities numerically 0 or 1 occurred

所以我显然做错了!我觉得修复此功能可能是最好的方法,但是我看不出哪里出了问题?我正在遵循我之前提到的方法here

我得到如下决策值的Sigmoid

sig_mult <-e1071::sigmoid(decision_values)

示例是使用Iris数据集完成的,完整代码在这里

data(iris)
two_class<-iris[iris$Species %in% c("setosa","versicolor"),]

#Make Two-class SVM 
svm_mult<-e1071::tune(svm,
                train.x = two_class[,1:4],
                train.y = factor(two_class[,5],levels=c("setosa", "versicolor")),
                type="C-classification", 
                kernel="radial",
                gamma=0.05, 
                cost=1,  
                probability = T,
                tunecontrol = tune.control(cross = 5))

#Get related decision values
dec_vals_mult <-attr(predict(svm_mult$best.model, 
                  two_class[,1:4], 
                  decision.values =  T #use decision values to get score
                  ), "decision.values")
#Get related probabilities
prob_mult <-attr(predict(svm_mult$best.model, 
                             two_class[,1:4], 
                             probability =  T #use decision values to get score
), "probabilities")[,1]

#transform decision values using sigmoid
sig_mult <-e1071::sigmoid(dec_vals_mult)
#Use Platt Implementation function to derive probabilities
platt_imp<-platt_scale(svm_mult,two_class[,1:4])

require(ggplot2)
data2<-as.data.frame(cbind(dec_vals_mult,sig_mult))
names(data2)<-c("Decision.Values","Sigmoid.Decision.Values(Prob)")
sig<-ggplot(data=data2,aes(x=Decision.Values,
                     y=`Sigmoid.Decision.Values(Prob)`,
                     colour=ifelse(Decision.Values<0,"neg","pos")))+
  geom_point()+
  ylim(0,1)+
  theme(legend.position = "none")

data3<-as.data.frame(cbind(dec_vals_mult,prob_mult))
names(data3)<-c("Decision.Values","Probabilities")
actual<-ggplot(data=data3,aes(x=Decision.Values,
                      y=Probabilities,
                      colour=ifelse(Decision.Values<0,"neg","pos")))+
  geom_point()+
  ylim(0,1)+
  theme(legend.position = "none")


data4<-as.data.frame(cbind(dec_vals_mult,platt_imp))
names(data4)<-c("Decision.Values","Platt")
plat_imp<-ggplot(data=data4,aes(x=Decision.Values,
                      y=Platt,
                      colour=ifelse(Decision.Values<0,"neg","pos")))+
  geom_point()+
  ylim(0,1)

require(ggpubr)
ggarrange(actual, plat_imp, sig,
          labels = c("Actual", "Platt Implementation", "Sigmoid Transformation"),
          ncol = 3,
          label.x = -.05,
          label.y = 1.001,
          font.label = list(size = 8.5, color = "black", face = "bold", family = NULL),
          common.legend = TRUE, legend = "bottom")

0 个答案:

没有答案