Tensorflow:下载并运行预训练的VGG或ResNet模型

时间:2019-02-08 14:46:18

标签: python tensorflow

让我们从头开始。到目前为止,我自己已经在Tensorflow中创建和培训了小型网络。在训练期间,我保存了模型并在目录中获取以下文件:

model.ckpt.meta
model.ckpt.index
model.ckpt.data-00000-of-00001

稍后,我加载保存在network_dir中的模型以进行一些分类并提取模型的可训练变量。

saver = tf.train.import_meta_graph(network_dir + ".meta")
variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="NETWORK")

现在,我想使用更大的经过预训练的模型,例如VGG16或ResNet,并希望使用我的代码来做到这一点。我想加载像我自己的网络这样的预训练模型,如上所示。

在此站点上,我发现了许多预先训练的模型:

https://github.com/tensorflow/models/tree/master/research/slim#pre-trained-models

我下载了VGG16检查点,并意识到这些只是经过训练的参数。

我想知道如何或在哪里可以获取这些预训练网络的保存的模型或图形结构?例如,如何使用没有model.ckpt.metamodel.ckpt.indexmodel.ckpt.data-00000-of-00001文件的VGG16检查点?

1 个答案:

答案 0 :(得分:0)

权重链接旁边,有指向定义模型的代码的链接。例如,对于VGG16:Code。使用代码创建模型并从检查点恢复变量:

import tensorflow as tf

slim = tf.contrib.slim

image = ...  # Define your input somehow, e.g with placeholder
logits, _ = vgg.vgg_16(image)
predictions = tf.argmax(logits, 1)
variables_to_restore = slim.get_variables_to_restore()

saver = tf.train.Saver(variables_to_restore)
with tf.Session() as sess:
    saver.restore(sess, "/path/to/model.ckpt")

因此,vgg.py中包含的代码将为您创建所有变量。使用tf-slim帮助器,您可以获得列表。然后,只需按照通常的步骤进行即可。上有a similar question