使用循环顺序更新data.table列

时间:2017-01-01 16:33:45

标签: r data.table

我使用data.table实现了一个简单的动态编程示例here,希望它能像矢量化代码一样快。

library(data.table)
B=100; M=50; alpha=0.5; beta=0.9;
n = B + M + 1
m = M + 1
u <- function(c)c^alpha
dt <- data.table(s = 0:(B+M))[, .(a = 0:min(s, M)), s] # State Space and corresponging Action Space
dt[, u := (s-a)^alpha,]                                # rewards r(s, a)
dt <- dt[, .(s_next = a:(a+B), u = u), .(s, a)]        # all possible (s') for each (s, a)
dt[, p := 1/(B+1), s]                                  # transition probs

#          s  a s_next  u        p
#     1:   0  0      0  0 0.009901
#     2:   0  0      1  0 0.009901
#     3:   0  0      2  0 0.009901
#     4:   0  0      3  0 0.009901
#     5:   0  0      4  0 0.009901
#    ---                          
#649022: 150 50    146 10 0.009901
#649023: 150 50    147 10 0.009901
#649024: 150 50    148 10 0.009901
#649025: 150 50    149 10 0.009901
#649026: 150 50    150 10 0.009901

为我的问题提供一些内容:以sa为条件,ss_next)的未来值将实现为{{1}之一每个都有概率a:(a+10)p=1/(B + 1)列为每个组合u提供u(s, a)

  • 根据(s, a)(贝尔曼方程式)给出每个唯一状态V的初始值n by 1(始终s向量),V更新。
  • 最大化是wrt V[s]=max(u(s, a)) + beta* sum(p*v(s_next)),因此,a在下面的迭代中。

实际上效率很高vectorized solution。我认为[, `:=`(v = max(v), i = s_next[which.max(v)]), by = .(s)]解决方案的性能与矢量化方法相当。

我知道主要罪魁祸首是data.table。唉,我不知道如何解决它。

dt[, v := V[s_next + 1]]

令我沮丧的是,# Iteration starts here system.time({ V <- rep(0, n) # initial guess for Value function i <- 1 tol <- 1 while(tol > 0.0001){ dt[, v := V[s_next + 1]] dt[, v := u + beta * sum(p*v), by = .(s, a) ][, `:=`(v = max(v), i = s_next[which.max(v)]), by = .(s)] # Iteration dt1 <- dt[, .(v[1L], i[1L]), by = s] Vnew <- dt1$V1 sig <- dt1$V2 tol <- max(abs(V - Vnew)) V <- Vnew i <- i + 1 } }) # user system elapsed # 5.81 0.40 6.25 解决方案甚至比以下高度非矢量化的解决方案更慢。作为一个草率的data.table-user,我必须缺少一些data.table功能。有没有办法改进,或者data.table不适合这些类型的计算?

data.table

1 个答案:

答案 0 :(得分:2)

这是我怎么做的......

DT = CJ(s = seq_len(n)-1L, a = seq_len(m)-1L, s_next = seq_len(n)-1L)
DT[ , p := 0]
#p is 0 unless this is true
DT[between(s_next, a, a + B), p := 1/(B+1)]
#may as well subset to eliminate irrelevant states
DT = DT[p>0 & s>=a]
DT[ , util := u(s - a)]

#don't technically need by, but just to be careful
DT[ , V0 := rep(0, n), by = .(a, s_next)]

while(TRUE) {
  #for each s, maximize given past value;
  #  within each s, have to sum over s_nexts,
  #  to do so, sum by a
  DT[ , V1 := max(.SD[ , util[1L] + beta*sum(V0*p), by = a],
               na.rm = TRUE), by = s]
  if (DT[ , max(abs(V0 - V1))] < 1e-4) break
  DT[ , V0 := V1]
}

在我的机器上这需要大约15秒(所以不好)......但也许这会给你一些想法。例如,这个data.table太大了,因为n最终只有V个唯一值。