比较沿任意维度重复的m维数组和(m-1)维数组

时间:2015-02-09 18:52:56

标签: arrays r

我已经实现了一个基于多维数组的计算,它取代了一些循环代码。我在这个过程中做了一些事情,我认为可以做得更好 - 但我不确定如何。

其中之一是将得到的3d数组与沿第三维重复的2d数组进行比较。

items12 = c(1,2,3,4,5,6)
items3 = c(1,2,3)

m2d = outer(items12, items12, "-")
m3d = outer(items3, m2d, "*")

经过一些操作后,我想比较m2d和m3d,其中m2d在第三个暗淡处重复。我知道有两个选择,看起来都不那么优雅,如果有更好的方法,我很好奇。

实例化重复的3d数组。记忆沉重但速度很快。

m2d.z.3d = outer(
  m2d, 
  rep(1, length(items3)), "*"
)

m3d - m2d.z.3d

循环。轻但慢。

apply(m3d, 3, function(x) {
    x - m2d
})

有什么建议吗?你会选择哪个?

更新 澄清任意索引要求的例子。

items12 = c(1,2,3)
items3 = c(1,2)

m2d = outer(items12, items12, "-")
m3d = outer(m2d,items3, "*")

m3d - (m3d - items.3)

# items.3 wrapped along rows
, , 1

     [,1] [,2] [,3]
[1,]    1    2    3
[2,]    1    2    3
[3,]    1    2    3

, , 2

     [,1] [,2] [,3]
[1,]    1    2    3
[2,]    1    2    3
[3,]    1    2    3

m3d.yx = aperm(m3d, c(2,1,3))
aperm(m3d.yx - (m3d.yx - c(items.3)), c(2,1,3)) 

#items.3 wrapped around columns
, , 1

     [,1] [,2] [,3]
[1,]    1    1    1
[2,]    2    2    2
[3,]    3    3    3

, , 2

     [,1] [,2] [,3]
[1,]    1    1    1
[2,]    2    2    2
[3,]    3    3    3

更新

在这种情况下精子的一些基准。

items.3 = rep(c(1,2,3), n)
items.2 = rep(c(1,2), n)

m2d = outer(items.3, items.3, "-")
m3d = outer(m2d, items.2, "*")

funRecycle = function() # items.3 wraps around the columns (index 1, then 2, then 3 etc.)
  m3d - (m3d - c(items.3)) 
funAperm = function() { # temporarily interchange index 1 and 2 to apply along desired index
  m3d.yx = aperm(m3d, c(2,1,3))
  aperm(m3d.yx - (m3d.yx - c(items.3)), c(2,1,3)) 
}
funOuter = function() { # assign the 3d matrix
  m2d.z.3d = outer(
    m2d, 
    rep(1, length(items.2)), "*"
  )
  m3d - m2d.z.3d
}
funArray = function() { # assign the 3d matrix with array
  m2d.z.3d = array(m2d, dim=c(dim(m2d)[1:2], length(items.2)))
  m3d - m2d.z.3d
}
funSweep <- function() sweep(m3d, c(1, 2), m2d, "-")

n = 1

Unit: microseconds
         expr    min      lq     mean  median      uq    max neval   cld
 funRecycle()  1.110  1.3875  1.65388  1.6650  1.9420  2.775   100 a    
   funAperm() 17.200 19.1420 21.23113 20.2520 20.9455 69.077   100    d 
   funOuter() 14.426 15.8130 17.58316 17.2005 18.1710 35.232   100   c  
   funArray()  2.774  3.3300  3.95079  3.8840  4.1610 14.148   100  b   
   funSweep() 31.903 32.7360 34.84129 33.5680 34.4000 62.141   100     e

N = 100

Unit: milliseconds
         expr       min        lq      mean    median        uq       max
 funRecycle()  28.51351  32.35671  37.13257  33.98931  39.94408  85.94085
   funAperm() 232.69297 276.07494 344.70083 352.40273 395.50492 569.54978
   funOuter()  35.25947  43.98674  53.06895  49.72790  55.93677  95.38608
   funArray()  96.78482 110.10501 119.68267 116.50378 120.70943 172.53973
   funSweep() 150.88675 168.90293 193.06270 178.11013 216.79349 291.23719

我对结果感到惊讶,不知何故,在大n中,将所有内容乘以1与外部变得比简单地使用array()复制数组更快。 (在大n的外部()看起来它可能比循环方法更快。)

如果我们必须在不同的索引(funAperm)之间进行比较,那么在所有情况下构建带外部的数组会更快。

除了精子之外还有任何建议可以跨任意索引进行比较吗?

1 个答案:

答案 0 :(得分:1)

假设你的意思是(我假设这是因为m3d - m2d.z.3d不起作用):

m3d = outer(m2d, items3, "*") # note how I switched the arguments

然后这个有效:

m3d - c(m2d)

证明:

all.equal(m3d - c(m2d), m3d - m2d.z.3d)
# [1] TRUE

这里我们只是利用矢量回收,因为我们想要沿着最后一个维度重复。我们需要做c()去除尺寸,否则R会抱怨数组不合理(虽然它们实际上是我们想要的特定意义)。

基于对R源代码(src/main/arithmetic.c:real_binary())的敷衍回顾,看起来矢量回收不会复制回收的矢量,因此这应该既快速又节省内存。

如果我们想沿任意维度执行此操作,我们必须使用aperm重新调整所有维度的数组,以使相关维度最后,然后将结果重新洗回原始维度顺序。这会增加一些开销。

关于选择什么方法,如果你没有内存不足,请使用快速方法(即避免循环支持完全向量化操作)。

此外,还有items12 <- seq(100)items3 <- seq(50)的一些基准:

funOuter <- function() {
  m2d.z.3d = outer(
    m2d, 
    rep(1, length(items3)), "*"
  )
  m3d - m2d.z.3d
}
funRecycle <- function() m3d - c(m2d)
funLoop <- function() apply(m3d, 3, "-", m2d)    # this does not appear correct because `apply` doesn't reconstruct dimensions like `sapply`
funSweep <- function() sweep(m3d, c(1, 2), m2d)  # this is the same type of thing but works properly

library(microbenchmark)
microbenchmark(funOuter(), funRecycle(), funLoop(), funSweep())

产地:

Unit: milliseconds
         expr       min        lq      mean    median
   funOuter()  2.297287  2.673768  3.232277  2.835404
 funRecycle()  1.327101  1.485082  2.093252  1.599543
    funLoop() 22.579010 24.586667 27.211804 26.840069
   funSweep() 11.251656 12.012664 13.516147 13.736908       

并检查结果:

all.equal(funOuter(), funRecycle())
# [1] TRUE
all.equal(funOuter(), funSweep())
# [1] TRUE
all.equal(funOuter(), funLoop())
# Nope, not equal