我知道这是一个常见的错误,但我没有理解这个问题。这是我的代码:
def convert_image(url):
checkpoint_file = './vgg_16.ckpt'
input_tensor = tf.placeholder(tf.float32, shape=(None,224,224,3), name='input_image')
scaled_input_tensor = tf.scalar_mul((1.0/255), input_tensor)
scaled_input_tensor = tf.subtract(scaled_input_tensor, 0.5)
scaled_input_tensor = tf.multiply(scaled_input_tensor, 2.0)
#Load the model
sess = tf.Session()
arg_scope = vgg_arg_scope()
with slim.arg_scope(arg_scope):
logits, end_points = vgg_16(scaled_input_tensor, is_training=False)
saver = tf.train.Saver()
saver.restore(sess, checkpoint_file)
response = requests.get(url)
img = Image.open(BytesIO(response.content))
im = np.array(img, dtype='float32')
im = im.reshape(-1,224,224,3)
features = sess.run(end_points['vgg_16/fc7'], feed_dict={input_tensor: im})
sess.close()
return np.squeeze(features)
如您所见,我正在使用VGG_16预训练模型来提取fc7功能。大约50%的代码只是从URL获取图像并将其转换为224x224x3;另外50%的张量流确实能够实现特征表示。
问题是,我第一次运行此代码时工作正常。但是,第二次,我得到了上述错误。当然,“im”是一个float32,即使我收到这个错误。因此,我认为这个问题与我第二次运行此功能时出现的问题有关。如果我不得不猜测,它与“保护程序”的工作方式有关,但我无法弄明白究竟是什么。
有什么想法吗?
答案 0 :(得分:1)
错误很可能是由于您重新定义了input_tensor,而不是在VGG模型中使用输入占位符。您可以在输入图像im
之前应用转换,然后再将其转换为网络。
此外,您为每个图像加载模型。 相反,加载模型一次,然后迭代循环内的图像列表。 像这样:
def convert_images(url_list):
# Load the TF model
#.....
# Session, etc.
# Now, go over the list of images one by one
for url in url_list:
image = ... # get image
features = session.run(...) # extract features