测试恢复的张量流模型的一般方法

时间:2017-02-24 22:27:23

标签: python-2.7 tensorflow

我对Tensorflow非常非常全新,需要编写一个脚本来测试从检查点文件恢复的模型上的单个示例。

我想知道是否有一般方法为恢复的模型构建测试函数,而不知道模型的所有细节。

此外,在下面的代码的最后一部分中,这看起来像是朝着正确的方向前进吗?如果是这样,如何在不知道模型详细信息的情况下构建“y”?

import tensorflow as tf
from tensorflow.python import pywrap_tensorflow
import numpy as np
from fuel.datasets.hdf5 import H5PYDataset

ckpt_path='ckt/mnist/mnist_2017_02_23_17_22_50/mnist_2017_02_23_17_22_50_5000.ckpt'

##############################
#### Initialize Variables ####
##############################

reader = pywrap_tensorflow.NewCheckpointReader(ckpt_path)
var_to_shape_map = reader.get_variable_to_shape_map()
var=[0]*len(var_to_shape_map)
i=0
for key in var_to_shape_map:
    var[i] = tf.Variable(reader.get_tensor(key), name=key)
    #print("tensor_name: ", key)
    #print(reader.get_tensor(key))
    i=i+1
initialize=tf.global_variables_initializer()

###############################
####### Restore Model #########
###############################

saver = tf.train.Saver()
sess = tf.Session()
saver.restore(sess, ckpt_path)

###############################
##### Get Example to Test #####
###############################

test_set = H5PYDataset('../CNN3D/data/bmnist.hdf5', which_sets=('test',))
handle = test_set.open()
for i in range(0,100):
    test_data = test_set.get_data(handle, slice(i, i+1))
    if test_data[1][0][0]==8:
        model_idx=i
test_data = test_set.get_data(handle, slice(model_idx,model_idx+1))
data = tf.Variable(np.asarray(test_data[0][0][0]), name='data')

###############################
######## Test Example #########
###############################

x = tf.placeholder(tf.float32,shape=[28,28])
y = ???
sess.run(initialize)
result=sess.run(y, feed_dict={x: data})
print result

1 个答案:

答案 0 :(得分:0)

Estimator类有一组方便的实用程序,如果你的模型包含在估算器中,那么从中加载和预测很容易。

总的来说,如果没有某种协调,这将很难。