如何有效地检查矩阵是否为二进制形式(例如,所有1或0)?

时间:2014-04-24 16:06:39

标签: r matrix

我有一个函数,它采用m×n大小(可能)的二进制矩阵作为输入,如果矩阵包含的数字不是0或1,或者是NA,我想返回错误处理。我怎样才能有效地检查这个?

例如,通过为10 x 10生成一些数据:

> n=10;m=10
> mat = round(matrix(runif(m*n), m, n))
> mat
        [,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8] [,9] [,10]
 [1,]    0    1    0    1    1    0    1    0    1     0
 [2,]    0    0    0    0    0    0    0    0    0     1
 [3,]    1    1    0    1    1    0    0    1    1     0
 [4,]    1    1    1    1    0    1    0    0    1     1
 [5,]    1    1    1    0    0    1    1    1    0     1
 [6,]    1    0    1    0    0    0    0    1    0     0
 [7,]    0    0    0    1    0    1    1    1    1     0
 [8,]    0    0    0    1    0    1    1    1    1     1
 [9,]    0    0    1    1    0    1    1    1    1     1
[10,]    1    0    1    1    0    0    0    0    1     1

应始终返回矩阵是二进制的,但是可以通过以下方式之一进行更改:

> mat[1,1]=NA
> mat[1,1]=2

应该返回矩阵不是二进制的。

目前,我一直在使用我的功能:

for(i in 1:nrow(mat))
{
    for(j in 1:ncol(mat))
    {
      if(is.na(mat[i,j])|(!(mat[i,j] == 1 | mat[i,j] == 0)))
      {
        stop("Data must be only 0s, 1s")
      }
    }
}

但单独检查大型矩阵的每个值似乎非常缓慢且效率低下。有没有一种聪明,简单的方法可以做到这一点我错过了?

由于

5 个答案:

答案 0 :(得分:5)

以下是一些选项的时间安排(包括其他答案中建议的选项):

n=5000;m=5000
mat = round(matrix(runif(m*n), m, n))
> system.time(stopifnot(sum(mat==0) + sum(mat==1) == length(mat)))
   user  system elapsed 
   0.30    0.02    0.31 
> system.time(stopifnot(all(mat %in% c(0,1))))
   user  system elapsed 
   0.58    0.06    0.63 
> system.time(stopifnot(all(mat==0 | mat==1)))
   user  system elapsed 
   0.77    0.03    0.80 

它们都非常快,考虑到5000 x 5000矩阵!三者中最快的似乎是:

stopifnot(sum(mat==0) + sum(mat==1) == length(mat))

答案 1 :(得分:5)

我立刻想到了identical(mat,matrix(as.numeric(as.logical(mat),nr=nrow(mat)) ) )

这会将NA留给NA,因此,如果您想确定此类内容的存在,您只需要快速any(is.na(mat))或类似的测试。

编辑:计时赛

fun2 <- function(x) {
      all(x %in% 0:1)
}
fun1 <-function(x) {identical(as.vector(x),as.numeric(as.logical(x)))}

mfoo<-matrix(sample(0:10,1e6,rep=TRUE),1e3)
 microbenchmark(fun1(mfoo),fun2(mfoo),is.binary.sum2(mfoo),times=10)
Unit: milliseconds
                 expr       min        lq    median        uq
           fun1(mfoo)  2.286941  2.809926  2.835584  2.865518
           fun2(mfoo) 20.369075 20.894627 21.100528 21.226464
 is.binary.sum2(mfoo) 11.394503 12.418238 12.431922 12.458436
       max neval
  2.920253    10
 21.407777    10
 28.316492    10

反对not...事情:我必须投入try以避免破坏测试。

notfun <- function(mat) try(stopifnot(sum(mat==0) + sum(mat==1) == length(mat)))
 microbenchmark(fun1(mfoo),notfun(mfoo),is.binary.sum2(mfoo),times=10)
Error : sum(mat == 0) + sum(mat == 1) == length(mat) is not TRUE
##error repeated 10x for the 10 trials
Unit: milliseconds
                 expr       min        lq    median        uq
           fun1(mfoo)  4.870653  4.978414  5.057524  5.268344
         notfun(mfoo) 18.149273 18.685942 18.942518 19.241856
 is.binary.sum2(mfoo) 11.428713 12.145842 12.516165 12.605111
       max neval
  5.438111    10
 34.826230    10
 13.090465    10

我赢了! : - )

答案 2 :(得分:4)

我想添加基于sum的比较略微修改的版本,比@ JamesTrimble的版本更快。我希望我的所有假设都是正确的:

is.binary.sum2 <- function(x) {
  identical(sum(abs(x)) - sum(x == 1), 0)
}

这里的基准:

library(rbenchmark)

n=5000
m=5000
mat = round(matrix(runif(m*n), m, n))

is.binary.sum <- function(x) {
  sum(x == 0) + sum(x == 1) == length(x)
}

is.binary.sum2 <- function(x) {
  identical(sum(abs(x)) - sum(x == 1), 0)
}

is.binary.all <- function(x) {
  all(x == 0 | x == 1)
}

is.binary.in <- function(x) {
  all(x %in% c(0, 1))
}

benchmark(is.binary.sum(mat), is.binary.sum2(mat),
          is.binary.all(mat), is.binary.in(mat),
          order="relative", replications=10)
#                 test replications elapsed relative user.self sys.self user.child sys.child
#2 is.binary.sum2(mat)           10   4.635    1.000     3.872    0.744          0         0
#1  is.binary.sum(mat)           10   7.097    1.531     6.565    0.512          0         0
#4   is.binary.in(mat)           10  10.359    2.235     9.216    1.108          0         0
#3  is.binary.all(mat)           10  12.565    2.711    11.753    0.772          0         0

答案 3 :(得分:3)

非常有效(和可读)的方式可能是

all(mat %in% c(0,1))

然而,正如所指出的那样,与其他解决方案相比,它可能不是最快的。

但是,添加一些,如果效率是必须的(例如,你做了很多测试 通过使用integer矩阵给出很多增益double s有更多字节)并检查integer值。 这种增益也可以应用于其他解决方案。 以%in%进行的一些测试如下:

library(microbenchmark)
set.seed(1)

my.dim <- 1e04
n <- my.dim
m <- my.dim
mat <- round(matrix(runif(m*n), m, n))
int.mat <- as.integer(mat)

fun1 <- function(x) {
      all(x %in% c(0,1))
}
fun2 <- function(x) {
      all(x %in% 0:1)
}

## why?
storage.mode(0:1)
## [1] "integer"
storage.mode(c(0,1))
## [1] "double"
object.size(0:1)
## 48 bytes
object.size(c(0,1))
## 56 bytes
## and considering mat and int.mat
object.size(mat)
## 800000200 bytes
object.size(int.mat)
## 400000040 bytes

(res <- microbenchmark(fun1(mat), fun2(int.mat), times = 10, unit = "s"))
## Unit: seconds
##           expr     min      lq  median      uq     max neval
##      fun1(mat) 3.68843 3.69325 3.70433 3.72627 3.73041    10
##  fun2(int.mat) 1.28956 1.29157 1.32934 1.34370 1.35718    10

从3.70到1.32 那么糟糕:)

答案 4 :(得分:1)

注意,我更改了一些内容,因此它在octave中运行,但它应该与matlab非常相似。

生成矩阵:

n=5000;m=5000
mat=randi([0,1],n,m);

现在我们只是做一些简单的事情,我们知道1*2-1会使1等于1,而0等于-1。所以,abs使它完全相同。对于任何其他值,请说-1-1*2-1=-3,情况并非如此。然后我们减去1,我们应该留下一个只有零的矩阵。这可以使用any

在matlab / octave中轻松检查
any(any(abs(mat*2-1)-1));

检查速度:

mat=randi([0,1],n,m);
[t0 u0 s0]=cputime(); any(any(abs(mat+mat-1)-1)); [t1 u1 s1]=cputime(); [t1-t0 u1-u0 s1-s0]
ans =
 0.176772   0.127546   0.049226

按顺序totalusersystem时间。

0.18秒内相当不错,其中大部分都处于用户模式。对于10.000 * 10.000条目,它仍然不到一秒钟,在我的系统上以0.86秒计时。

哦,哎呀,我现在才发现它实际上是R,而不是matlab。我希望有人喜欢这种比较。

NaN / octave matlab可以轻松处理isnan(mat)值,如果您愿意,最终可以采用any(any(isnan(mat)))的形式。这包括NA值。仅处理NA值是isna(mat)