我正在尝试计算tensorflow
中的互信息,并将其与我的numpy
代码进行比较:
import numpy as np
import tensorflow as tf
import nibabel as nib
def nmi(vol1,vol2,bin):
m=vol1.shape[2]
n=vol2.shape[2]
if m>n:
d=m-n
slices=np.zeros([vol1.shape[0],vol2.shape[1],d+n])
slices[:,:,0:n]=vol2
vol2=slices
else:
d=n-m
slices=np.zeros([vol1.shape[0],vol2.shape[1],d+m])
slices[:,:,0:m]=vol1
vol1=slices
h,x_edges,y_edges=np.histogram2d(vol1.ravel(),vol2.ravel(),bins=bin)
pxy=h/np.sum(h)
px = np.sum(pxy, axis=1)
py = np.sum(pxy, axis=0) # marginal for y over x
px_py = px[:, None] * py[None, :] # Broadcast to multiply marginals
#现在我们可以使用pxy,px_py 2D数组进行计算 nzs = pxy> 0#只有非零pxy值对和
有贡献 return np.sum(pxy[nzs] * np.log(pxy[nzs] / px_py[nzs]))
现在要与之比较的tensorflow代码如下:
def nmi_tf(x,y,bin):
vol1=tf.placeholder(dtype=tf.float32,shape=
(x.shape),name='volume1')
vol2=tf.placeholder(dtype=tf.float32,shape=
(x.shape),name='volume2')
sess=tf.Session()
tf.global_variables_initializer()
value_range=[0,1]
x1=tf.histogram_fixed_width(tf.cast(vol1,dtype=tf.float32),
value_range,nbins=bin)
x2=tf.histogram_fixed_width(tf.cast(vol2,dtype=tf.float32),
value_range,nbins=bin)
pxy=tf.concat([x1,x2],0)
pxy=tf.cast(tf.histogram_fixed_width(pxy,value_range,nbins=bin),
dtype=tf.float32)
pxy=pxy/tf.reduce_sum(pxy)
px=tf.reduce_sum(tf.cast(x1,dtype=tf.float32))
py=tf.reduce_sum(tf.cast(x2,dtype=tf.float32))
p_x_y=tf.multiply(px,py)
nmi=tf.reduce_sum(pxy/p_x_y)
print(sess.run(nmi,feed_dict={vol1:x,vol2:y}))