R中的NMI实施错误吗?

时间:2019-05-21 20:01:18

标签: r

#calculate NMI(c,t) c : cluster assignment , t : ground truth

NMI <- function(c,t){
n <- length(c) # = length(t)
r <- length(unique(c))
g <- length(unique(t))

N <- matrix(0,nrow = r , ncol = g)
for(i in 1:r){
    for (j in 1:g){
        N[i,j] = sum(t[c == i] == j)
    }
}

N_t <- colSums(N)
N_c <- rowSums(N)

B <- (1/n)*log(t( t( (n*N) / N_c ) / N_t))
W <- B*N
I <- sum(W,na.rm = T) 



H_c <- sum((1/n)*(N_c * log(N_c/n)) , na.rm = T)
H_t <- sum((1/n)*(N_t * log(N_t/n)) , na.rm = T)    

nmi <- I/sqrt(H_c * H_t)

return (nmi)
}

在某些群集基准测试here上运行,可以为我提供标准化互信息的值。但是,当我将其与从aricode库获得的NMI值进行比较时,我得到的NMI值通常在小数点后第二位有所不同。

如果有人能够指出此代码中可能出现的任何错误,我将不胜感激。

我正在使用综合案例包括一个测试案例:

library(aricode)
c <- c(1,1,2,2,2,3,3,3,3,4,4,4)
t <- c(1,2,2,2,3,4,3,3,3,4,4,2)
print(aricode::NMI(c , t))   #0.489574
print(NMI(c,t))              #0.5030771

1 个答案:

答案 0 :(得分:2)

这可能已经很晚了,但为了后代:

区别在于您和 aricode 包标准化索引的方式。您除以 sqrt()aricode 提供以下选项: function (c1, c2, variant = c("max", "min", "sqrt", "sum", "joint"))

因此,如果您选择 variant = sqrt,您应该会得到相同的答案。

NMI 包使用 sum。