在R中绘制梯度下降的矢量图

时间:2016-04-14 09:03:20

标签: r plot machine-learning gradient-descent

我在R中编写了梯度下降算法,现在我试图"绘制"向量的路径。

我的轮廓图中有绘制点,但它不正确,因为没有人知道先发生了什么。

在我的算法中,我总是有一个先前的状态P =(Xi,Yi)和后来的状态L =(Xi + 1,Yi + 1),那么,我怎样才能在{{1}中绘制向量PL }或contour情节?

我只用persp得到了这个,其中红点是收敛:

contour http://oi65.tinypic.com/afh638.jpg

contour

相同

persp http://oi67.tinypic.com/2hppxj6.jpg

全部谢谢!

编辑:

图形可以分别被忽略:

persp

2 个答案:

答案 0 :(得分:0)

?persp中所述,使用persp绘制点是一种技巧。通过使用trans3d的力量,您可以成功地将点和线放在透视图上。

f<-function(u,v){
  u*u*exp(2*v)+4*v*v*exp(-2*u)-4*u*v*exp(v-u)
}

x = seq(-2, 2, by = 0.5)
y = seq(-2, 2, by = 0.5)
z <- scale(outer(x,y,f))

view <- persp(x, y, z, phi = 30, theta = 30, xlim=c(-2,2), ylim=c(-2,2),
      xlab = "X", ylab = "Y", zlab = "Z", scale = FALSE,
      main = "F(u,v)", col="yellow", ticktype = "detailed")

set.seed(2)
pts <- data.frame(x = sample(x, 3),
                  y = sample(y, 3),
                  z = sample(z, 3))

points(trans3d(x = pts$x, y = pts$y, z = pts$z, pmat = view), pch = 16)
lines(trans3d(x = pts$x, y = pts$y, z = pts$z, pmat = view))

enter image description here

答案 1 :(得分:0)

Himmelblau's function作为测试示例:

f <- function(x, y) { (x^2+y-11)^2 + (x+y^2-7)^2 }

其偏导数:

dx <- function(x,y) {4*x**3-4*x*y-42*x+4*x*y-14}
dy <- function(x,y) {4*y**3+2*x**2-26*y+4*x*y-22}

运行梯度下降:

# gradient descent parameters
num_iter <- 100
learning_rate <- 0.001
x_val <- 6
y_val <- 6

updates_x <- vector("numeric", length = num_iter)
updates_y <- vector("numeric", length = num_iter)
updates_z <- vector("numeric", length = num_iter)

# parameter updates
for (i in 1:num_iter) {

  dx_val = dx(x_val,y_val)
  dy_val = dy(x_val,y_val)

  x_val <- x_val-learning_rate*dx_val
  y_val <- y_val-learning_rate*dx_val
  z_val <- f(x_val, y_val)

  updates_x[i] <- x_val
  updates_y[i] <- y_val
  updates_z[i] <- z_val
}

绘图:

x <- seq(-6, 6, length = 100)
y <- x
z <- outer(x, y, f)

plt <- persp(x, y, z,
               theta = -50-log(i), phi = 20+log(i),
               expand = 0.5,
               col = "lightblue", border = 'lightblue',
               axes = FALSE, box = FALSE,
               ltheta = 60, shade = 0.90
  )

points(trans3d(updates_x[1:i], updates_y[1:i], updates_z[1:i],pmat = plt),
       col = c(rep('white', num_iter-1), 'blue'),
       pch = 16,
       cex = c(rep(0.5, num_iter-1), 1))

enter image description here