节省时间在6嵌套for循环代码中

时间:2017-06-24 00:18:44

标签: r optimization time

我正在尝试为我的代码找到合适的参数,这是扩展卡尔曼滤波器的一个回溯。我有6个嵌套for循环,每个参数一个。目前,当每个参数有3个可能的值时,代码平均需要大约5分钟才能运行,但是当我增加参数数量时,我的时间会增加为n ^ 6。我有点担心。

有什么办法可以优化代码以节省更多时间吗?

PS - 只使用任何数据文件而不是给定的Reddy.csv(1180行数据) PPS - 最后我需要找到最小MSE的i,j,k,l,m,n值。

以下是代码:

start.time <- Sys.time()

library(invgamma)
w = read.csv("Reddy.csv")
q = ts(w[2])
num = length(q)

f = function(x){
  f1 = sqrt(x)
  return(f1)
}
h = function(x){
  h1 = x**3
  return(h1)
}


ae1 = seq(24,26,0.1)
ae2 = seq(24,26,0.1)

be1 = seq(0.1,2,0.1)
be2 = seq(0.1,2,0.1)

a = seq(1,3,0.1)
b = seq(0.1,2,0.1)

count = 0

MSE = matrix(nrow = length(ae1)*length(ae2)*length(be1)*length(be2)*length(a)*length(b), ncol =7)

for (i in ae1){
  for (j in ae2){
    for (k in be1){
      for (l in be2){
        for (m in a){
          for (n in b){
            d = rep(0,num)
            xt = rep(0,num)
            yt = rep(0,num)
            fx = rep(0,num)
            hx = rep(0,num)

            e = rinvgamma(num,i,k)
            g = rinvgamma(num,j,l)
            for(o in 2:num){
              fx[o] = f(xt[o-1])
              xt[o] = m*fx[o] + e[o-1]
              hx[o] = h(xt[o])
              yt[o]= n*hx[o] +g[o]
              d[o] = (yt[o] - q[o])**2
            }
            count <- count + 1
            MSE[count,1] = mean(d)
            MSE[count,2] = i
            MSE[count,3] = j
            MSE[count,4] = k
            MSE[count,5] = l
            MSE[count,6] = m
            MSE[count,7] = n
            t = rbind(mean(d),i,j,k,l,m,n)
            print(t)
          }
        }
      }
    }
  }
}

end.time <- Sys.time()
time.taken <- end.time - start.time
time.taken

m = which.min(MSE[,1])
MSE[m,]

1 个答案:

答案 0 :(得分:0)

通过矢量化一些最内部的循环计算,可以实现进一步的优化,

start.time <- Sys.time()

library(invgamma)
w = rnorm(1180)
q = ts(w)
num = length(q)

f = function(x){
    f1 = sqrt(x)
    return(f1)
}
h = function(x){
    h1 = x**3
    return(h1)
}


ae1 = seq(24,26,1)
ae2 = seq(24,26,1)

be1 = seq(0.1,2,0.7)
be2 = seq(0.1,2,0.7)

a = seq(1,3,1)
b = seq(0.1,2,0.7)

count = 0

MSE = matrix(nrow = length(ae1)*length(ae2)*length(be1)*length(be2)*length(a)*length(b), ncol =7)

for (i in ae1){
    for (j in ae2){
        for (k in be1){
            for (l in be2){
                for (m in a){
                    for (n in b){
                        d = rep(0,num)
                        xt = rep(0,num)
                        yt = rep(0,num)
                        fx = rep(0,num)
                        hx = rep(0,num)
                        e = rinvgamma(num,i,k)
                        g = rinvgamma(num,j,l)
                        for(o in 2:num){
                            fx[o] = f(xt[o-1])
                            xt[o] = m*fx[o] + e[o-1]
                        }
                        hx = h(xt)
                        yt = n*hx +g
                        d = (yt - q)**2

                        count <- count + 1
                        MSE[count,1] = mean(d)
                        MSE[count,2] = i
                        MSE[count,3] = j
                        MSE[count,4] = k
                        MSE[count,5] = l
                        MSE[count,6] = m
                        MSE[count,7] = n
                        ## t = rbind(mean(d),i,j,k,l,m,n)
                        ## print(t)
                    }
                }
            }
        }
    }
}

end.time <- Sys.time()
time.taken <- end.time - start.time
time.taken

m = which.min(MSE[,1])
MSE[m,]

在我的笔记本电脑上,这可以使代码快几倍。

编辑:

如果组合太多,并且您不想创建和修改大型矩阵,因为您只关心MSE的最小值,您只能记录最佳组合,如下所示:< / p>

MSEopt <- Inf

for (i in ae1){
    for (j in ae2){
        for (k in be1){
            for (l in be2){
                for (m in a){
                    for (n in b){
                        d = rep(0,num)
                        xt = rep(0,num)
                        yt = rep(0,num)
                        fx = rep(0,num)
                        hx = rep(0,num)
                        e = rinvgamma(num,i,k)
                        g = rinvgamma(num,j,l)
                        for(o in 2:num){
                            fx[o] = f(xt[o-1])
                            xt[o] = m*fx[o] + e[o-1]
                        }
                        hx = h(xt)
                        yt = n*hx +g
                        d = (yt - q)**2

                        if (mean(d) < MSEopt) {
                            ## print(MSEopt)
                            MSEopt <- mean(d)
                            best_combn <- list(i = i, j = j, k = k, l = l, m = m, n = n)
                        }
                    }
                }
            }
        }
    }
}