我在Tensorflow上训练了一个CNN模型,我想重复使用来进行分类和测试。 这就是我目前正在做的事情:
def test(trained_model):
# returns a iterator.get.next()
x_test, y_test = inputs('test_set.tfrecords', batch_size=128, training_size=10000, shuffle=False, num_epochs=1)
# get the output of the cnn
predictions = tf.nn.softmax(AlexNet(x_test))
with tf.name_scope('Accuracy'):
# Accuracy
acc = tf.equal(tf.argmax(predictions, 1), tf.argmax(y_test, 1))
acc = tf.reduce_mean(tf.cast(acc, tf.float32))
# Initializing the variables
init = tf.global_variables_initializer()
with tf.Session() as new_sess:
saver = tf.train.import_meta_graph(trained_model)
saver.restore(new_sess,tf.train.latest_checkpoint('./'))
graph = tf.get_default_graph()
cnt = 1
try:
while(True):
new_sess.run(init)
print(acc.eval(), cnt)
cnt+=1
except tf.errors.OutOfRangeError:
print('Finished batch')
它似乎有效,但它与我发现的other answers不同,人们使用graph.get_tensor_by_name("y_:0")
和feed_dict
我不明白。
谁能告诉我,我正在做的事情是对的,什么是正确的工作流程?
答案 0 :(得分:1)
你所做的是正确的,没有“正确的工作流程”(tl;博士:他们在逻辑上是等同的。)
使用Saver
保存模型时,Tensorflow会自动为您创建.meta
和.ckpt
个文件,其中.meta
包含图表定义(列表节点及其连接)和.ckpt
文件包含模型参数。
tf.train.import_meta_graph
在当前默认图表中加载.meta
文件中保存的图表定义,restore()
调用使用ckpt
文件的权重集填充图表
显然,如果当前默认图表已经具有import_meta_graph
尝试定义的相同定义,则跳过定义步骤。
这意味着如果您在导入元图之前已经定义了相同的图,则可以使用python变量(例如predictions
)来引用图中的节点。
相反,如果您还没有定义图形,import_meta_graph
将为您定义图形,但您不会准备好使用任何python变量。
因此,您必须从图中提取对所需节点的引用,并创建一个要使用的python变量(例如input = graph.get_tensor_by_name("logits:0")
)