#calculate NMI(c,t) c : cluster assignment , t : ground truth
NMI <- function(c,t){
n <- length(c) # = length(t)
r <- length(unique(c))
g <- length(unique(t))
N <- matrix(0,nrow = r , ncol = g)
for(i in 1:r){
for (j in 1:g){
N[i,j] = sum(t[c == i] == j)
}
}
N_t <- colSums(N)
N_c <- rowSums(N)
B <- (1/n)*log(t( t( (n*N) / N_c ) / N_t))
W <- B*N
I <- sum(W,na.rm = T)
H_c <- sum((1/n)*(N_c * log(N_c/n)) , na.rm = T)
H_t <- sum((1/n)*(N_t * log(N_t/n)) , na.rm = T)
nmi <- I/sqrt(H_c * H_t)
return (nmi)
}
在某些群集基准测试here上运行,可以为我提供标准化互信息的值。但是,当我将其与从aricode库获得的NMI值进行比较时,我得到的NMI值通常在小数点后第二位有所不同。
如果有人能够指出此代码中可能出现的任何错误,我将不胜感激。
我正在使用综合案例包括一个测试案例:
library(aricode)
c <- c(1,1,2,2,2,3,3,3,3,4,4,4)
t <- c(1,2,2,2,3,4,3,3,3,4,4,2)
print(aricode::NMI(c , t)) #0.489574
print(NMI(c,t)) #0.5030771
答案 0 :(得分:2)
这可能已经很晚了,但为了后代:
区别在于您和 aricode
包标准化索引的方式。您除以 sqrt()
而 aricode
提供以下选项:
function (c1, c2, variant = c("max", "min", "sqrt", "sum", "joint"))
因此,如果您选择 variant = sqrt
,您应该会得到相同的答案。
NMI
包使用 sum。