Multi-parameter gradient descent in R

Posted: 2019-01-23 03:02:45

Tags: r linear-regression gradient-descent

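I am an R user and I am currently trying to use the gradient descent algorithm and compare it with multiple linear regression. I have found some code online, but it does not work on every dataset. I am using the UCI bike-sharing dataset (hourly) as an example.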

The dataset can be found here: https://archive.ics.uci.edu/ml/machine-learning-databases/00275/

Split the data into training and test sets and create the matrices:

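# Assumption: 'data' already holds the hourly table from the archive above, e.g. data <- read.csv("hour.csv")
data1 <- data[, c("season", "mnth", "hr", "holiday", "weekday", "workingday", "weathersit", "temp", "atemp", "hum", "windspeed", "cnt")]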

# Split the data
trainingObs<-sample(nrow(data1),0.70*nrow(data1),replace=FALSE)

# Create the training dataset
trainingDS<-data1[trainingObs,]

# Create the test dataset
testDS<-data1[-trainingObs,]

x0 <- rep(1, nrow(trainingDS)) # column of 1's
x1 <- trainingDS[, c("season", "mnth", "hr", "holiday", "weekday", "workingday", "weathersit", "temp", "atemp", "hum", "windspeed")]

# create the x- matrix of explanatory variables
x <- as.matrix(cbind(x0,x1))

# create the y-matrix of dependent variables

y <- as.matrix(trainingDS$cnt)
m <- nrow(y)

solve(t(x)%*%x)%*%t(x)%*%y  
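The last line solves the normal equations X'X beta = X'y by explicitly inverting X'X. A minimal sketch of an equivalent form that solves the linear system directly, checked against lm(); it uses the built-in mtcars data only as a stand-in, because the bike-sharing file does not ship with R:

```r
# Stand-in data: built-in mtcars (not the bike-sharing data from the post)
X_mt <- cbind(x0 = 1, as.matrix(mtcars[, c("hp", "wt", "gear", "cyl")]))
y_mt <- mtcars$mpg

# Solve (X'X) beta = X'y directly instead of forming solve(X'X) first
beta_normal <- solve(crossprod(X_mt), crossprod(X_mt, y_mt))

# lm() fits the same model via a QR decomposition; it adds the intercept itself
fit <- lm(mpg ~ hp + wt + gear + cyl, data = mtcars)

# The two columns should agree up to floating-point error
cbind(normal_eq = drop(beta_normal), lm = coef(fit))
```

crossprod(X) computes t(X) %*% X. lm() reaches the same estimates through a QR decomposition, which is usually the safer route when predictors are nearly collinear.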

Next comes the gradient descent function:

gradientDesc <- function(x, y, learn_rate, conv_threshold, max_iter) {
  n <- nrow(x) 
  m <- runif(ncol(x), 0, 1) # m is a vector of dimension ncol(x), 1
  yhat <- x %*% m # since x already contains a constant, no need to add another one
  
  MSE <- sum((y - yhat) ^ 2) / n
  
  converged <- FALSE
  iterations <- 0
  
  while (!converged) {
    m <- m - learn_rate * (1/n * t(x) %*% (yhat - y))
    yhat <- x %*% m
    MSE_new <- sum((y - yhat) ^ 2) / n
    
    if (abs(MSE - MSE_new) <= conv_threshold) {
      converged <- TRUE
    }
    iterations <- iterations + 1
    MSE <- MSE_new
    
    if (iterations >= max_iter) break
  }
  return(list(converged = converged, 
              num_iterations = iterations, 
              MSE = MSE_new, 
              coefs = m) )
} 
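For reference, the update inside the loop is batch gradient descent with a fixed step. Writing eta for learn_rate, X for x, beta for the coefficient vector m and k for the iteration count, one pass computes the following (a derivation added for clarity, not taken from the post):

```latex
\begin{aligned}
J(\beta) &= \frac{1}{2n}\,\lVert y - X\beta \rVert^2,
\qquad \nabla J(\beta) = \frac{1}{n}\,X^\top (X\beta - y),\\
\beta^{(k+1)} &= \beta^{(k)} - \eta\,\frac{1}{n}\,X^\top\bigl(X\beta^{(k)} - y\bigr),\\
\beta^{(k+1)} - \hat\beta &= (I - \eta H)\bigl(\beta^{(k)} - \hat\beta\bigr),
\qquad H = \frac{1}{n}\,X^\top X,\quad X^\top X\,\hat\beta = X^\top y.
\end{aligned}
```

The MSE reported by the function equals 2J(beta), so it has the same minimizer, and the fixed point beta-hat solves the same normal equations as the ols line. With X of full column rank, the iteration converges from every starting point exactly when 0 < eta < 2 / lambda_max(H), and the error along the eigenvector with the smallest eigenvalue lambda_min shrinks by only |1 - eta * lambda_min| per step.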

ols <- solve(t(x)%*%x)%*%t(x)%*%y  

out <- gradientDesc(x,y, 0.005, 1e-7, 200000) 

data.frame(ols, out$coefs)  

This works and produces the following comparison between the multiple regression (OLS) and gradient descent solutions:

                 ols    out.coefs
x0           30.8003341   33.4473667
season       19.7839676   19.8020073
mnth         -0.1249776   -0.1290033
hr            7.4554424    7.4619508
holiday     -15.6022846  -15.8630012
weekday       1.8238997    1.7930636
workingday    5.0487553    5.0088699
weathersit   -2.2088254   -2.3389047
temp         85.6214524  141.1351024
atemp       235.5992391  173.1234342
hum        -226.7253991 -226.1559532
windspeed    33.5144866   30.1245570
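The two columns disagree most on temp and atemp. A plausible explanation, not verified on this data, is that these two predictors (actual and "feels-like" temperature) are strongly correlated, which makes X'X ill-conditioned: the error along the nearly flat direction shrinks by only (1 - eta * lambda_min) per step, so the MSE barely changes and the run can end, through either the stopping rule or max_iter, while those two coefficients are still far from OLS. A minimal sketch with simulated data (the variables u1 and u2, the sample size and the noise level are made up for illustration):

```r
set.seed(1)
n_sim <- 200

# Hypothetical predictors: u2 is u1 plus a little noise, i.e. nearly collinear
u1 <- runif(n_sim)
u2 <- u1 + rnorm(n_sim, sd = 0.01)
X_sim <- cbind(x0 = 1, u1, u2)

# H = (1/n) X'X governs how fast each direction converges
H <- crossprod(X_sim) / n_sim
ev <- eigen(H, symmetric = TRUE, only.values = TRUE)$values

eta <- 0.005                        # the learn_rate used in the post
ratio <- max(ev) / min(ev)          # condition number of H
slow_factor <- 1 - eta * min(ev)    # per-step shrink factor along the flattest direction
steps_per_e <- 1 / (eta * min(ev))  # roughly the steps needed to shrink that error by a factor e

c(condition_number = ratio, slow_factor = slow_factor, steps_per_e = steps_per_e)
```

With these made-up settings, steps_per_e comes out far above the 200000 iterations allowed in the post, so the coefficients on u1 and u2 individually would not be expected to match OLS even though their sum can.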

It also works on the iris dataset with exactly the same commands as before:

iris 
head(iris) 
data2 <-iris[,c("Sepal.Width", "Petal.Length","Petal.Width","Sepal.Length")]


# Split the data
trainingObs1<-sample(nrow(data2),0.70*nrow(data2),replace=FALSE)

# Create the training dataset
trainingDS1<-data2[trainingObs1,]

# Create the test dataset
testDS2<-data2[-trainingObs1,]

x0a <- rep(1, nrow(trainingDS1)) # column of 1's
x1a<-trainingDS1[, c("Sepal.Width", "Petal.Length","Petal.Width")]
z <- as.matrix(cbind(x0a,x1a))

y<-as.matrix(trainingDS1$Sepal.Length) 
m<-nrow(y)


solve(t(z)%*%z)%*%t(z)%*%y 

ols <- solve(t(z)%*%z)%*%t(z)%*%y  

out <- gradientDesc(z,y, 0.005, 1e-7, 200000) 

data.frame(ols, out$coefs)  

which produces the following output:

                   ols  out.coefs
x0a           1.7082712  1.3933410
Sepal.Width   0.6764848  0.7578847
Petal.Length  0.7225420  0.7571403
Petal.Width  -0.5436298 -0.6001406

However, when I use it with the mtcars dataset:

mtcars<-mtcars 
head(mtcars) 
data3<-mtcars[,c("hp","wt","gear","cyl","mpg")] 
trainingObs2<-sample(nrow(data3),0.70*nrow(data3),replace=FALSE) 
trainingDS2<-data3[trainingObs2,] 
testDS3<-data3[-trainingObs2,] 
x0b <- rep(1, nrow(trainingDS2)) # column of 1's
x1b<-trainingDS2[, c("hp", "wt","gear","cyl")]
w <- as.matrix(cbind(x0b,x1b)) 
y<-as.matrix(trainingDS2$mpg)  
m<-nrow(y)
solve(t(w)%*%w)%*%t(w)%*%y  

ols <- solve(t(w)%*%w)%*%t(w)%*%y  
out <- gradientDesc(w,y, 0.005, 1e-7, 200000) 
data.frame(ols, out$coefs)  
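Applying the standard step-size condition for fixed-step gradient descent (the update diverges from almost every start once learn_rate exceeds 2 / lambda_max of (1/n) X'X) to the two design matrices suggests why mtcars behaves differently: hp takes values in the hundreds, so lambda_max is large and 0.005 is far too big a step, whereas the iris measurements are single-digit. A minimal sketch, using the full datasets rather than the random 70% training splits, so the exact numbers differ slightly from the post's w and z:

```r
# Largest eigenvalue of H = (1/n) X'X for a design matrix X
lambda_max <- function(X) {
  max(eigen(crossprod(X) / nrow(X), symmetric = TRUE, only.values = TRUE)$values)
}

# Design matrices built like w (mtcars) and z (iris) in the post, but from the full data
w_full <- cbind(x0b = 1, as.matrix(mtcars[, c("hp", "wt", "gear", "cyl")]))
z_full <- cbind(x0a = 1, as.matrix(iris[, c("Sepal.Width", "Petal.Length", "Petal.Width")]))

# Same mtcars predictors after standardizing (mean 0, sd 1); the intercept column stays 1
w_scaled <- cbind(x0b = 1, scale(mtcars[, c("hp", "wt", "gear", "cyl")]))

learn_rate <- 0.005
max_step <- c(
  mtcars        = 2 / lambda_max(w_full),
  iris          = 2 / lambda_max(z_full),
  mtcars_scaled = 2 / lambda_max(w_scaled)
)
max_step               # largest stable learn_rate for each design matrix
max_step > learn_rate  # FALSE means a step of 0.005 is expected to diverge
```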

the comparison fails with the following errors:

> ols <- solve(t(w)%*%w)%*%t(w)%*%y  
> out <- gradientDesc(w,y, 0.005, 1e-7, 200000) 
Error in if (abs(MSE - MSE_new) <= conv_threshold) { : 
  missing value where TRUE/FALSE needed
> data.frame(ols, out$coefs)
Error in data.frame(ols, out$coefs) : 
  arguments imply differing number of rows: 5, 4
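Both messages are consistent with a diverging run, although the post does not show the intermediate values. Once the coefficients overflow, MSE_new becomes Inf or NaN; Inf - Inf and any arithmetic with NaN give NaN, the comparison with conv_threshold is then NA, and if() cannot use an NA condition. Because gradientDesc() stopped with an error, out was never reassigned, so it most likely still holds the four-coefficient iris result, which would explain "differing number of rows: 5, 4". A small sketch reproducing the first message without any data:

```r
# Values typical of a run whose coefficients have overflowed
MSE <- Inf
MSE_new <- Inf

gap <- abs(MSE - MSE_new)  # Inf - Inf is NaN
is.nan(gap)                # TRUE
gap <= 1e-7                # NA, which if() cannot use as a condition

# Reproduce the error from the post
msg <- tryCatch(
  if (gap <= 1e-7) "converged",
  error = function(e) conditionMessage(e)
)
msg  # "missing value where TRUE/FALSE needed" (in an English locale)

# One possible guard (an addition, not in the original function), placed
# right after MSE_new is computed inside the while loop:
#   if (!is.finite(MSE_new)) break
# The function would then return converged = FALSE (with non-finite
# coefficients) instead of stopping with an error.
```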

Thanks for any help and pointers, and thank you very much for your time.
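One common remedy, not tried in the post, is to standardize the predictors, run the same update on the scaled problem, and map the coefficients back to the original units. A minimal sketch on the full mtcars data; the function name gd_scaled, the step of 0.1 and the fixed iteration count (no early stopping) are illustrative choices, not values from the post:

```r
gd_scaled <- function(X, y, learn_rate = 0.1, n_iter = 20000) {
  # Standardize each predictor column (mean 0, sd 1) and prepend an intercept column
  centers <- colMeans(X)
  scales <- apply(X, 2, sd)
  Z <- cbind(1, sweep(sweep(X, 2, centers), 2, scales, "/"))

  # Same gradient step as gradientDesc(), applied to the scaled design matrix
  n <- nrow(Z)
  gamma <- numeric(ncol(Z))
  for (i in seq_len(n_iter)) {
    gamma <- gamma - learn_rate * drop(crossprod(Z, Z %*% gamma - y)) / n
  }

  # Back-transform: y = g0 + sum_j g_j * (x_j - c_j) / s_j
  slopes <- gamma[-1] / scales
  intercept <- unname(gamma[1] - sum(slopes * centers))
  c("(Intercept)" = intercept, slopes)
}

X_mt <- as.matrix(mtcars[, c("hp", "wt", "gear", "cyl")])
beta_gd <- gd_scaled(X_mt, mtcars$mpg)
beta_lm <- coef(lm(mpg ~ hp + wt + gear + cyl, data = mtcars))
cbind(gd = beta_gd, lm = beta_lm)
```

After scaling, the largest eigenvalue of (1/n) Z'Z is at most the number of predictors, so a step of 0.1 sits well inside the stable range; the fixed iteration count sidesteps the MSE-change stopping rule entirely.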

0 answers