如何加载由Google命名的预训练张量流模型?

时间:2016-10-11 02:44:23

标签: python tensorflow

我已经下载了一个名为inception_resnet_v2_2016_08_30.ckpt的张量流检查点模型。

我是否需要创建创建此检查点时使用的图形(包含所有变量)?

如何使用此模型?

2 个答案:

答案 0 :(得分:4)

首先,你们已经在内存中获得了网络架构。您可以从here

获取网络架构

一旦有了这个程序,请使用以下方法来使用该模型:

from inception_resnet_v2 import inception_resnet_v2, inception_resnet_v2_arg_scope

height = 299
width = 299
channels = 3

X = tf.placeholder(tf.float32, shape=[None, height, width, channels])
with slim.arg_scope(inception_resnet_v2_arg_scope()):
     logits, end_points = inception_resnet_v2(X, num_classes=1001,is_training=False)

这样你就拥有了内存中的所有网络,现在你可以使用tf.train.saver用checkpoint文件(ckpt)初始化网络:

saver = tf.train.Saver()
sess = tf.Session()
saver.restore(sess, "/home/pramod/Downloads/inception_resnet_v2_2016_08_30.ckpt")

如果你想要进行瓶子特征提取,它的简单就好比你想要从最后一层获得功能,那么只需要声明predictions = end_points["Logits"]如果你想获得其他中间层,你可以从上面的程序inception_resnet_v2.py

中获取这些名称

之后,您可以致电:output = sess.run(predictions, feed_dict={X:batch_images})

答案 1 :(得分:2)

  

我是否需要创建创建此检查点时使用的图形(包含所有变量)?

不,你没有。

关于如何使用检查点文件(cpkt文件)

1.本文(TensorFlow-Slim image classification library)告诉您如何从头开始训练您的模型

2.以下是google blog

的示例代码
import numpy as np
import os
import tensorflow as tf
import urllib2

from datasets import imagenet
from nets import inception
from preprocessing import inception_preprocessing

slim = tf.contrib.slim

batch_size = 3
image_size = inception.inception_v3.default_image_size

checkpoints_dir = '/root/code/model'
checkpoints_filename = 'inception_resnet_v2_2016_08_30.ckpt'
model_name = 'InceptionResnetV2'
sess = tf.InteractiveSession()
graph = tf.Graph()
graph.as_default()

def classify_from_url(url):
    image_string = urllib2.urlopen(url).read()
    image = tf.image.decode_jpeg(image_string, channels=3)
    processed_image = inception_preprocessing.preprocess_image(image,     image_size, image_size, is_training=False)
processed_images  = tf.expand_dims(processed_image, 0)

# Create the model, use the default arg scope to configure the batch norm parameters.
with slim.arg_scope(inception.inception_resnet_v2_arg_scope()):
    logits, _ = inception.inception_resnet_v2(processed_images, num_classes=1001, is_training=False)
probabilities = tf.nn.softmax(logits)

init_fn = slim.assign_from_checkpoint_fn(
    os.path.join(checkpoints_dir, checkpoints_filename),
    slim.get_model_variables(model_name))

init_fn(sess)
np_image, probabilities = sess.run([image, probabilities])
probabilities = probabilities[0, 0:]
sorted_inds = [i[0] for i in sorted(enumerate(-probabilities), key=lambda x:x[1])]

plt.figure()
plt.imshow(np_image.astype(np.uint8))
plt.axis('off')
plt.show()

names = imagenet.create_readable_names_for_imagenet_labels()
for i in range(5):
    index = sorted_inds[i]
    print('Probability %0.2f%% => [%s]' % (probabilities[index], names[index]))