如何重用经过训练的模型来执行分类 - Tensorflow

时间:2018-05-05 12:13:01

标签: tensorflow

我在Tensorflow上训练了一个CNN模型,我想重复使用来进行分类和测试。 这就是我目前正在做的事情:

def test(trained_model):
    # returns a iterator.get.next()
    x_test, y_test = inputs('test_set.tfrecords', batch_size=128, training_size=10000, shuffle=False, num_epochs=1)
    # get the output of the cnn
    predictions = tf.nn.softmax(AlexNet(x_test))
    with tf.name_scope('Accuracy'):
        # Accuracy
           acc = tf.equal(tf.argmax(predictions, 1), tf.argmax(y_test, 1))
           acc = tf.reduce_mean(tf.cast(acc, tf.float32))

    # Initializing the variables
    init = tf.global_variables_initializer()
    with tf.Session() as new_sess:
        saver = tf.train.import_meta_graph(trained_model)
        saver.restore(new_sess,tf.train.latest_checkpoint('./'))
        graph = tf.get_default_graph()
        cnt = 1
        try:
            while(True):
                new_sess.run(init)
                print(acc.eval(), cnt)
                cnt+=1

        except tf.errors.OutOfRangeError:
            print('Finished batch')

它似乎有效,但它与我发现的other answers不同,人们使用graph.get_tensor_by_name("y_:0")feed_dict我不明白。 谁能告诉我,我正在做的事情是对的,什么是正确的工作流程?

1 个答案:

答案 0 :(得分:1)

你所做的是正确的,没有“正确的工作流程”(tl;博士:他们在逻辑上是等同的。)

使用Saver保存模型时,Tensorflow会自动为您创建.meta.ckpt个文件,其中.meta包含图表定义(列表节点及其连接)和.ckpt文件包含模型参数。

tf.train.import_meta_graph在当前默认图表中加载.meta文件中保存的图表定义,restore()调用使用ckpt文件的权重集填充图表

显然,如果当前默认图表已经具有import_meta_graph尝试定义的相同定义,则跳过定义步骤。

这意味着如果您在导入元图之前已经定义了相同的图,则可以使用python变量(例如predictions)来引用图中的节点。

相反,如果您还没有定义图形,import_meta_graph将为您定义图形,但您不会准备好使用任何python变量。

因此,您必须从图中提取对所需节点的引用,并创建一个要使用的python变量(例如input = graph.get_tensor_by_name("logits:0")