在CNN的图像管道中使用`tf.to_float()`或`tf.image.convert_image_dtype()`?

时间:2018-01-12 04:12:26

标签: tensorflow tensorflow-slim

我正在使用此文件作为模板vgg_preprocessing.py修改tf.slim示例。

当我使用tf.slim笔记本(slim_walkthrough.ipynb)中的剪辑从TFRecord文件中读取数据时,我得到的图像颜色失真。当预处理脚本使用tf.to_float()将图像张量从tf.uint8更改为tf.float32时,就会发生这种情况。

image = tf.to_float(image) enter image description here

image = tf.image.convert_image_dtype(image, dtype=tf.float32) enter image description here

通过CNN运行后差异是否重要?如果是这样,哪一个更适合Vgg16图像处理管道?如果我切换到不同的预训练模型,如Inception

,这是否重要?

以下是完整的方法:

# tf.to_float() and tf.image.convert_image_dtype() give different results
def preprocess_for_train(image,
                     output_height,
                     output_width):
  # randomly crop to 224x244
  image = _random_crop([image], output_height, output_width)[0]
  image.set_shape([output_height, output_width, 3])

  image = tf.to_float(image)
  # image = tf.image.convert_image_dtype(image, dtype=tf.float32)

  image = tf.image.random_flip_left_right(image)
  return image

2 个答案:

答案 0 :(得分:0)

我意识到我的问题完全不同了。

上述问题的答案是:

  • tf.to_float([1,2,3])仅生成[1.,2.,3.]
  • tf.image.convert_image_dtype([image tensor with dtype=tf.uint8], dtype=tf.float32)生成一个图像张量,已经标准化为[0..1]
  • 之间的值

但我的错误是因为matplotlib.pyplot.imshow(image)不适用dtype=tf.float32导致的mean_image_subtraction的{​​{1}}的负值。我发现将值转换回Vgg16似乎解决了uint8

的所有问题

imshow()

答案 1 :(得分:0)

首先,请参见代码说明:

img_tensor = tf.image.decode_jpeg(img_raw)
print(img_tensor.shape)
print(img_tensor.dtype)
print(img_tensor.numpy().max())

a = tf.image.convert_image_dtype(img_tensor, dtype=tf.float32)
print(a.numpy().max())
print(a.shape)
print(a.dtype)

b = tf.to_float(img_tensor)
print(b.numpy().max())
print(b.shape)
print(b.dtype)

c = tf.cast(img_tensor,dtype=tf.float32)
print(c.numpy().max())
print(c.shape)
print(c.dtype)

结果是:

(28, 28, 3)
<dtype: 'uint8'>
149

## for tf.image.convert_image_dtype
0.58431375
(28, 28, 3)
<dtype: 'float32'>

## for tf.to_float
WARNING:tensorflow:From <ipython-input-6-c51a71006d6e>:13: to_float (from 
tensorflow.python.ops.math_ops) is deprecated and will be removed in a future 
version.
Instructions for updating:
Use tf.cast instead.
149.0
(28, 28, 3)
<dtype: 'float32'>

## for tf.cast 
149.0
(28, 28, 3)
<dtype: 'float32'>

从上面的代码和结果中,您可以获得

  1. tf.to_float已过时,因此建议使用tf.cast;
  2. tf.to_float相乘1 / 255.0等于tf.image.convert_image_dtype操作;

所以,在我看来,没有太大差异。

顺便说一下,TF版本是:1.13.1。