尝试使用架构和检查点inception_resnet_v2_2016_08_30.ckpt运行Inceptionv2 Tensorflow模型。我的代码用于预测给定图像的每种分类的概率。
我尝试根据很棒的博客here使用类构造tensorflow代码。但是我们有错误:
NotFoundError (see above for traceback): Tensor name "prediction/InceptionResnetV2/AuxLogits/Conv2d_1b_1x1/BatchNorm/beta"not found in checkpoint files inception_resnet_v2_2016_08_30.ckpt.
我的错误代码如下。
from inception_resnet_v2 import *
import functools
import inception_preprocessing
import matplotlib.pyplot as plt
import os
import numpy as np
import tensorflow as tf
from scipy.misc import imread
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
def doublewrap(function):
"""
A decorator decorator, allowing to use the decorator to be used without
parentheses if no arguments are provided. All arguments must be optional.
"""
@functools.wraps(function)
def decorator(*args, **kwargs):
if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
return function(args[0])
else:
return lambda wrapee: function(wrapee, args, *kwargs)
return decorator
@doublewrap
def define_scope(function, scope=None, args, *kwargs):
"""
A decorator for functions that define TensorFlow operations. The wrapped
function will only be executed once. Subsequent calls to it will directly
return the result so that operations are added to the graph only once.
The operations added by the function live within a tf.variable_scope(). If
this decorator is used with arguments, they will be forwarded to the
variable scope. The scope name defaults to the name of the wrapped
function.
"""
attribute = '_cache_' + function.__name__
name = scope or function.__name__
@property
@functools.wraps(function)
def decorator(self):
if not hasattr(self, attribute):
with tf.variable_scope(name, args, *kwargs):
setattr(self, attribute, function(self))
return getattr(self, attribute)
return decorator
class Inception(object):
def __init__(self,
image):
self.image = image
self.process_data # call function process_data
self.prediction
@define_scope
def process_data(self):
image_size = inception_resnet_v2.default_image_size
image = inception_preprocessing.preprocess_image(self.image, image_size, image_size, is_training=False, )
image1 = tf.expand_dims(image, 0)
return image1
@define_scope
def prediction(self):
'''Creates the Inception Resnet V2 model.'''
arg_scope = inception_resnet_v2_arg_scope()
with tf.contrib.slim.arg_scope(arg_scope):
logits, end_points = inception_resnet_v2(self.process_data, is_training=False)
probabilities = tf.nn.softmax(logits)
return probabilities
def main():
tf.reset_default_graph()
image = tf.placeholder(tf.float32, [None, None, 3])
model = Inception(image)
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess,
'inception_resnet_v2_2016_08_30.ckpt')
probabilities = sess.run(model.prediction, feed_dict={image: data})
print(probabilities)
if _name_ == '__main__':
data = imread('ILSVRC2012_test_00000003 .JPEG', mode='RGB').astype(np.float)
main()
但是,如果我们不使用上述类构造代码,那么我们就可以成功运行。 以下是无错误运行的代码。
from inception_resnet_v2 import *
import inception_preprocessing
import os
import numpy as np
import tensorflow as tf
from scipy.misc import imread
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
slim = tf.contrib.slim
tf.reset_default_graph()
# prepare data
data = imread('ILSVRC2012_test_00000003.JPEG', mode='RGB').astype(np.float)
image = tf.placeholder(tf.float32, [None, None, 3])
# pre-processing image
image_size = inception_resnet_v2.default_image_size
processed_image = inception_preprocessing.preprocess_image(image, image_size, image_size, is_training=False,)
processed_image = tf.expand_dims(processed_image, 0)
# Creates the Inception Resnet V2 model.
arg_scope = inception_resnet_v2_arg_scope()
with slim.arg_scope(arg_scope):
logits, end_points = inception_resnet_v2(processed_image, is_training=False)
probabilities = tf.nn.softmax(logits)
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess, './inception_resnet_v2_2016_08_30.ckpt')
print(sess.run(probabilities, feed_dict={image:data}))
任何帮助将不胜感激!
答案 0 :(得分:0)
装饰器将Inception网络包装到以函数prediction
命名的变量范围中。结果,检查点中的变量名不再与图中的变量名匹配。
要验证这一点,可以在装饰器中将tf.variable_scope()
更改为tf.name_scope()
。在大多数情况下,这也不应影响程序的其余部分。
如果需要变量作用域,可以将dict传递到tf.train.Saver()
中,以将检查点中的变量名映射到图中的变量对象。
也可以通过使用tf.python.pywrap_tensorflow. NewCheckpointReader()
在检查点中读取变量名来实现此目的,但是我没有准备共享的代码示例。