Tensorflow:您必须使用dtype float

时间:2017-08-04 02:44:11

标签: python tensorflow

我知道这是一个常见的错误,但我没有理解这个问题。这是我的代码:

def convert_image(url):

    checkpoint_file = './vgg_16.ckpt'

    input_tensor = tf.placeholder(tf.float32, shape=(None,224,224,3), name='input_image')
    scaled_input_tensor = tf.scalar_mul((1.0/255), input_tensor)
    scaled_input_tensor = tf.subtract(scaled_input_tensor, 0.5)
    scaled_input_tensor = tf.multiply(scaled_input_tensor, 2.0)

    #Load the model
    sess = tf.Session()
    arg_scope = vgg_arg_scope()
    with slim.arg_scope(arg_scope):
        logits, end_points = vgg_16(scaled_input_tensor, is_training=False)
    saver = tf.train.Saver()
    saver.restore(sess, checkpoint_file)

    response = requests.get(url)
    img = Image.open(BytesIO(response.content))
    im = np.array(img, dtype='float32')
    im = im.reshape(-1,224,224,3)

    features = sess.run(end_points['vgg_16/fc7'], feed_dict={input_tensor: im})
    sess.close()
    return np.squeeze(features)

如您所见,我正在使用VGG_16预训练模型来提取fc7功能。大约50%的代码只是从URL获取图像并将其转换为224x224x3;另外50%的张量流确实能够实现特征表示。

问题是,我第一次运行此代码时工作正常。但是,第二次,我得到了上述错误。当然,“im”是一个float32,即使我收到这个错误。因此,我认为这个问题与我第二次运行此功能时出现的问题有关。如果我不得不猜测,它与“保护程序”的工作方式有关,但我无法弄明白究竟是什么。

有什么想法吗?

1 个答案:

答案 0 :(得分:1)

错误很可能是由于您重新定义了input_tensor,而不是在VGG模型中使用输入占位符。您可以在输入图像im之前应用转换,然后再将其转换为网络。

此外,您为每个图像加载模型。 相反,加载模型一次,然后迭代循环内的图像列表。 像这样:

def convert_images(url_list):
   # Load the TF model
   #.....
   # Session, etc.

   # Now, go over the list of images one by one
   for url in url_list:
      image = ... # get image
      features = session.run(...) # extract features