我有一个函数,它采用m×n大小(可能)的二进制矩阵作为输入,如果矩阵包含的数字不是0或1,或者是NA,我想返回错误处理。我怎样才能有效地检查这个?
例如,通过为10 x 10生成一些数据:
> n=10;m=10
> mat = round(matrix(runif(m*n), m, n))
> mat
[,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8] [,9] [,10]
[1,] 0 1 0 1 1 0 1 0 1 0
[2,] 0 0 0 0 0 0 0 0 0 1
[3,] 1 1 0 1 1 0 0 1 1 0
[4,] 1 1 1 1 0 1 0 0 1 1
[5,] 1 1 1 0 0 1 1 1 0 1
[6,] 1 0 1 0 0 0 0 1 0 0
[7,] 0 0 0 1 0 1 1 1 1 0
[8,] 0 0 0 1 0 1 1 1 1 1
[9,] 0 0 1 1 0 1 1 1 1 1
[10,] 1 0 1 1 0 0 0 0 1 1
应始终返回矩阵是二进制的,但是可以通过以下方式之一进行更改:
> mat[1,1]=NA
> mat[1,1]=2
应该返回矩阵不是二进制的。
目前,我一直在使用我的功能:
for(i in 1:nrow(mat))
{
for(j in 1:ncol(mat))
{
if(is.na(mat[i,j])|(!(mat[i,j] == 1 | mat[i,j] == 0)))
{
stop("Data must be only 0s, 1s")
}
}
}
但单独检查大型矩阵的每个值似乎非常缓慢且效率低下。有没有一种聪明,简单的方法可以做到这一点我错过了?
由于
答案 0 :(得分:5)
以下是一些选项的时间安排(包括其他答案中建议的选项):
n=5000;m=5000
mat = round(matrix(runif(m*n), m, n))
> system.time(stopifnot(sum(mat==0) + sum(mat==1) == length(mat)))
user system elapsed
0.30 0.02 0.31
> system.time(stopifnot(all(mat %in% c(0,1))))
user system elapsed
0.58 0.06 0.63
> system.time(stopifnot(all(mat==0 | mat==1)))
user system elapsed
0.77 0.03 0.80
它们都非常快,考虑到5000 x 5000矩阵!三者中最快的似乎是:
stopifnot(sum(mat==0) + sum(mat==1) == length(mat))
答案 1 :(得分:5)
我立刻想到了identical(mat,matrix(as.numeric(as.logical(mat),nr=nrow(mat)) ) )
这会将NA
留给NA
,因此,如果您想确定此类内容的存在,您只需要快速any(is.na(mat))
或类似的测试。
编辑:计时赛
fun2 <- function(x) {
all(x %in% 0:1)
}
fun1 <-function(x) {identical(as.vector(x),as.numeric(as.logical(x)))}
mfoo<-matrix(sample(0:10,1e6,rep=TRUE),1e3)
microbenchmark(fun1(mfoo),fun2(mfoo),is.binary.sum2(mfoo),times=10)
Unit: milliseconds
expr min lq median uq
fun1(mfoo) 2.286941 2.809926 2.835584 2.865518
fun2(mfoo) 20.369075 20.894627 21.100528 21.226464
is.binary.sum2(mfoo) 11.394503 12.418238 12.431922 12.458436
max neval
2.920253 10
21.407777 10
28.316492 10
反对not...
事情:我必须投入try
以避免破坏测试。
notfun <- function(mat) try(stopifnot(sum(mat==0) + sum(mat==1) == length(mat)))
microbenchmark(fun1(mfoo),notfun(mfoo),is.binary.sum2(mfoo),times=10)
Error : sum(mat == 0) + sum(mat == 1) == length(mat) is not TRUE
##error repeated 10x for the 10 trials
Unit: milliseconds
expr min lq median uq
fun1(mfoo) 4.870653 4.978414 5.057524 5.268344
notfun(mfoo) 18.149273 18.685942 18.942518 19.241856
is.binary.sum2(mfoo) 11.428713 12.145842 12.516165 12.605111
max neval
5.438111 10
34.826230 10
13.090465 10
我赢了! : - )
答案 2 :(得分:4)
我想添加基于sum
的比较略微修改的版本,比@ JamesTrimble的版本更快。我希望我的所有假设都是正确的:
is.binary.sum2 <- function(x) {
identical(sum(abs(x)) - sum(x == 1), 0)
}
这里的基准:
library(rbenchmark)
n=5000
m=5000
mat = round(matrix(runif(m*n), m, n))
is.binary.sum <- function(x) {
sum(x == 0) + sum(x == 1) == length(x)
}
is.binary.sum2 <- function(x) {
identical(sum(abs(x)) - sum(x == 1), 0)
}
is.binary.all <- function(x) {
all(x == 0 | x == 1)
}
is.binary.in <- function(x) {
all(x %in% c(0, 1))
}
benchmark(is.binary.sum(mat), is.binary.sum2(mat),
is.binary.all(mat), is.binary.in(mat),
order="relative", replications=10)
# test replications elapsed relative user.self sys.self user.child sys.child
#2 is.binary.sum2(mat) 10 4.635 1.000 3.872 0.744 0 0
#1 is.binary.sum(mat) 10 7.097 1.531 6.565 0.512 0 0
#4 is.binary.in(mat) 10 10.359 2.235 9.216 1.108 0 0
#3 is.binary.all(mat) 10 12.565 2.711 11.753 0.772 0 0
答案 3 :(得分:3)
非常有效(和可读)的方式可能是
all(mat %in% c(0,1))
然而,正如所指出的那样,与其他解决方案相比,它可能不是最快的。
但是,添加一些,如果效率是必须的(例如,你做了很多测试
通过使用integer
矩阵给出很多增益
(double
s有更多字节)并检查integer
值。
这种增益也可以应用于其他解决方案。
以%in%
进行的一些测试如下:
library(microbenchmark)
set.seed(1)
my.dim <- 1e04
n <- my.dim
m <- my.dim
mat <- round(matrix(runif(m*n), m, n))
int.mat <- as.integer(mat)
fun1 <- function(x) {
all(x %in% c(0,1))
}
fun2 <- function(x) {
all(x %in% 0:1)
}
## why?
storage.mode(0:1)
## [1] "integer"
storage.mode(c(0,1))
## [1] "double"
object.size(0:1)
## 48 bytes
object.size(c(0,1))
## 56 bytes
## and considering mat and int.mat
object.size(mat)
## 800000200 bytes
object.size(int.mat)
## 400000040 bytes
(res <- microbenchmark(fun1(mat), fun2(int.mat), times = 10, unit = "s"))
## Unit: seconds
## expr min lq median uq max neval
## fun1(mat) 3.68843 3.69325 3.70433 3.72627 3.73041 10
## fun2(int.mat) 1.28956 1.29157 1.32934 1.34370 1.35718 10
从3.70到1.32 不那么糟糕:)
答案 4 :(得分:1)
注意,我更改了一些内容,因此它在octave
中运行,但它应该与matlab
非常相似。
生成矩阵:
n=5000;m=5000
mat=randi([0,1],n,m);
现在我们只是做一些简单的事情,我们知道1*2-1
会使1
等于1
,而0
等于-1
。所以,abs
使它完全相同。对于任何其他值,请说-1
,-1*2-1=-3
,情况并非如此。然后我们减去1
,我们应该留下一个只有零的矩阵。这可以使用any
:
any(any(abs(mat*2-1)-1));
检查速度:
mat=randi([0,1],n,m);
[t0 u0 s0]=cputime(); any(any(abs(mat+mat-1)-1)); [t1 u1 s1]=cputime(); [t1-t0 u1-u0 s1-s0]
ans =
0.176772 0.127546 0.049226
按顺序total
,user
和system
时间。
在0.18
秒内相当不错,其中大部分都处于用户模式。对于10.000 * 10.000
条目,它仍然不到一秒钟,在我的系统上以0.86
秒计时。
哦,哎呀,我现在才发现它实际上是R
,而不是matlab
。我希望有人喜欢这种比较。
NaN
/ octave
matlab
可以轻松处理isnan(mat)
值,如果您愿意,最终可以采用any(any(isnan(mat)))
的形式。这包括NA
值。仅处理NA
值是isna(mat)
。