我对Tensorflow非常非常全新,需要编写一个脚本来测试从检查点文件恢复的模型上的单个示例。
我想知道是否有一般方法为恢复的模型构建测试函数,而不知道模型的所有细节。
此外,在下面的代码的最后一部分中,这看起来像是朝着正确的方向前进吗?如果是这样,如何在不知道模型详细信息的情况下构建“y”?
import tensorflow as tf
from tensorflow.python import pywrap_tensorflow
import numpy as np
from fuel.datasets.hdf5 import H5PYDataset
ckpt_path='ckt/mnist/mnist_2017_02_23_17_22_50/mnist_2017_02_23_17_22_50_5000.ckpt'
##############################
#### Initialize Variables ####
##############################
reader = pywrap_tensorflow.NewCheckpointReader(ckpt_path)
var_to_shape_map = reader.get_variable_to_shape_map()
var=[0]*len(var_to_shape_map)
i=0
for key in var_to_shape_map:
var[i] = tf.Variable(reader.get_tensor(key), name=key)
#print("tensor_name: ", key)
#print(reader.get_tensor(key))
i=i+1
initialize=tf.global_variables_initializer()
###############################
####### Restore Model #########
###############################
saver = tf.train.Saver()
sess = tf.Session()
saver.restore(sess, ckpt_path)
###############################
##### Get Example to Test #####
###############################
test_set = H5PYDataset('../CNN3D/data/bmnist.hdf5', which_sets=('test',))
handle = test_set.open()
for i in range(0,100):
test_data = test_set.get_data(handle, slice(i, i+1))
if test_data[1][0][0]==8:
model_idx=i
test_data = test_set.get_data(handle, slice(model_idx,model_idx+1))
data = tf.Variable(np.asarray(test_data[0][0][0]), name='data')
###############################
######## Test Example #########
###############################
x = tf.placeholder(tf.float32,shape=[28,28])
y = ???
sess.run(initialize)
result=sess.run(y, feed_dict={x: data})
print result