TensorFlow TFRecords: loading images in order

Date: 2017-03-10 07:51:24

Tags: tensorflow

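After converting my own image set into a TFRecords file, I'm having trouble loading the images one at a time. At both training and test time they are loaded in random order, which is fine for training, but for testing I need the images to come out one by one, in order.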

Converting my own 2D grayscale *.png images to TFRecords

I used build_image_data.py to convert my images to TFRecords, and it seems to work fine :)
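For context (this comes from the build_image_data.py source, not from the question): before writing, the script shuffles its list of image files with a fixed seed, random.seed(12345), so that labels are mixed within the output file. The records are therefore not stored in directory order, but the stored order is the same on every conversion. The sketch below only mimics that step in plain Python; shuffled_like_build_script is a hypothetical helper, not part of the script.

```python
import random


def shuffled_like_build_script(filenames, seed=12345):
    """Hypothetical helper: reorder filenames with a fixed-seed shuffle.

    It follows the idea of the shuffle in build_image_data.py (shuffle an
    index list with a fixed seed) but uses a private Random instance so the
    global random state is left untouched.
    """
    shuffled_index = list(range(len(filenames)))
    rng = random.Random(seed)
    rng.shuffle(shuffled_index)
    return [filenames[i] for i in shuffled_index]


if __name__ == "__main__":
    files = ["data%d.png" % i for i in range(10)]
    first = shuffled_like_build_script(files)
    second = shuffled_like_build_script(files)
    print(first)
    print(first == second)  # same seed, same order on every call
```

So a file written this way looks "random" when read front to back, yet reading it sequentially should give the same sequence each time.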

Reading the TFRecords file

To read images from the TFRecords file, I use the following code:

import tensorflow as tf


def getImage(filename):
    # convert filenames to a queue for an input pipeline.
    filenameQ = tf.train.string_input_producer([filename], num_epochs=None)

    # object to read records
    recordReader = tf.TFRecordReader()

    # read the full set of features for a single example
    key, fullExample = recordReader.read(filenameQ)

    # parse the full example into its component features.
    features = tf.parse_single_example(
        fullExample,
        features={
            'image/height': tf.FixedLenFeature([], tf.int64),
            'image/width': tf.FixedLenFeature([], tf.int64),
            'image/colorspace': tf.FixedLenFeature([], dtype=tf.string, default_value=''),
            'image/channels':  tf.FixedLenFeature([], tf.int64),
            'image/class/label': tf.FixedLenFeature([], tf.int64),
            'image/class/text': tf.FixedLenFeature([], dtype=tf.string, default_value=''),
            'image/format': tf.FixedLenFeature([], dtype=tf.string, default_value=''),
            'image/filename': tf.FixedLenFeature([], dtype=tf.string, default_value=''),
            'image/encoded': tf.FixedLenFeature([], dtype=tf.string, default_value='')
        })

    # now we are going to manipulate the label and image features
    label = features['image/class/label']
    image_buffer = features['image/encoded']

    # Decode the JPEG (build_image_data.py re-encodes PNG input as JPEG)
    with tf.name_scope('decode_jpeg', values=[image_buffer]):
        # decode
        image = tf.image.decode_jpeg(image_buffer, channels=3)

        # and convert to single precision data type
        image = tf.image.convert_image_dtype(image, dtype=tf.float32)

    # cast image into a single array, where each element corresponds to the greyscale
    # value of a single pixel.
    image = tf.reshape(tf.image.rgb_to_grayscale(image), [101 * 201])

    # re-define label as a "one-hot" vector (build_image_data.py numbers classes from 1)
    label = tf.one_hot(label - 1, 4)

    return label, image


# associate the "label" and "image" objects with the corresponding features read from
# a single example in the training data file
label, image = getImage("../image-to-tfrecords/train-00000-of-00001")

# associate "imageBatch" and "labelBatch" with a randomly shuffled batch
# of images and labels respectively
imageBatch, labelBatch = tf.train.shuffle_batch(
    [image, label], batch_size=100,
    capacity=2000,
    min_after_dequeue=1000)
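The shuffle_batch call above is one source of randomness: it keeps a buffer of at least min_after_dequeue examples (at most capacity) and dequeues from it at random positions. The plain-Python sketch below is a simplified, hypothetical model of that behaviour (not TensorFlow code; it ignores threads and the capacity bound), contrasted with FIFO batching, which is roughly what tf.train.batch with num_threads=1 gives in TF 1.x.

```python
import random


def shuffle_buffer_order(stream, min_after_dequeue, seed=None):
    """Simplified model of a shuffling queue (hypothetical helper).

    Items are enqueued in stream order; whenever more than
    min_after_dequeue items are buffered, one is removed at random.
    """
    rng = random.Random(seed)
    buffer, out = [], []
    for item in stream:
        buffer.append(item)
        if len(buffer) > min_after_dequeue:
            # dequeue one item from a random position in the buffer
            out.append(buffer.pop(rng.randrange(len(buffer))))
    # end of input: drain the remaining items, still in random order
    while buffer:
        out.append(buffer.pop(rng.randrange(len(buffer))))
    return out


def fifo_batches(stream, batch_size):
    """FIFO batching: items keep the order in which they were read."""
    items = list(stream)
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]


if __name__ == "__main__":
    records = list(range(20))
    print(fifo_batches(records, 5)[0])
    print(shuffle_buffer_order(records, min_after_dequeue=10, seed=0)[:5])
```

In this model, min_after_dequeue=0 degenerates to FIFO order, which illustrates that the buffer, not the file, is what scrambles the sequence.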

Making predictions with the images

Then I load my saved model and its weights, predict_op, etc., and evaluate the prediction op as follows:

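with tf.Session() as sess:
    # start the queue-runner threads (without them sess.run(image) blocks)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    for i in range(1):
        batch_xs = sess.run(image)
        batch_xs = np.reshape(batch_xs, (-1, 101 * 201))  # img_w * img_h
        prediction = sess.run([predict_op], feed_dict={x: batch_xs})
    coord.request_stop()
    coord.join(threads)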

where image is the image tensor returned by the getImage function.

Directory structure of the images used to build the TFRecords (for reference only)

├── train
|   ├── chef
|       ├── data0.png
|       ├── ...
|       └── data5467.png
|   ├── chicken
|       ├── data0.png
|       ├── ...
|       └── data2098.png
|   ├── parasaurolophus
|       ├── data0.png
|       ├── ...
|       └── data2977.png
|   └── Trex
|       ├── data0.png
|       ├── ...
|       └── data2841.png
├── validation
|   ├── ...
├── build_image_data.py
└── labels.txt

Update

I tried printing the filenames of the images that getImage receives, and it fetches the images from the TFRecords file in random order.

First run:

[output]:
data4918.png
data4984.png
data1144.png
data2186.png
data138.png
data573.png
data2590.png
data392.png
data846.png
data3222.png

Second run:

[output]:
data5212.png
data1144.png
data3588.png
data4054.png
data2938.png
data3396.png
data4711.png
data3222.png
data5003.png
data1298.png

Shouldn't the images be read in the same order every time?
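A TFRecords file is a plain sequence of length-prefixed records, so a single reader walking one file returns them in the order they were written. Each record is framed as a little-endian uint64 length, a uint32 masked CRC of the length, the payload bytes, and a uint32 masked CRC of the payload. The pure-Python sketch below walks that framing to list payloads in file order; iter_tfrecord_payloads is a hypothetical name, and it neither verifies the CRCs nor decodes the tf.train.Example protobufs, so treat it as an illustration only. If the order seen through getImage still changes between runs, one possible cause (an assumption, not verified here) is that the queue-runner threads feeding shuffle_batch pull records from the same reader concurrently with sess.run(image).

```python
import struct


def iter_tfrecord_payloads(path):
    """Yield raw record payloads from a TFRecords file in stored order.

    Hypothetical helper: reads the framing only; it does NOT check the
    CRCs and does NOT parse the tf.train.Example inside each payload.
    """
    with open(path, "rb") as f:
        while True:
            header = f.read(8)
            if not header:
                break  # clean end of file
            if len(header) < 8:
                raise ValueError("truncated length header")
            (length,) = struct.unpack("<Q", header)
            f.read(4)  # masked CRC of the length (ignored)
            payload = f.read(length)
            if len(payload) < length:
                raise ValueError("truncated record payload")
            f.read(4)  # masked CRC of the payload (ignored)
            yield payload
```

Reading the same file twice with this loop yields the same sequence both times, because nothing in it is random.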

1 Answer

Answer (score: 1)

As a temporary workaround, I created the following function, which reads the raw png images without using TFRecords.

However, there must be a way to do this with TFRecords! If you have any ideas, please post an answer!

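import tensorflow as tf


def readImage(filenames):
    # shuffle=False keeps the filename queue in the order given
    filenameQ = tf.train.string_input_producer(filenames, shuffle=False)

    reader = tf.WholeFileReader()  # reads each file in full as one record
    key, value = reader.read(filenameQ)

    image = tf.image.decode_png(value, channels=1)  # grayscale
    image.set_shape([101, 201, 1])
    return image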

image = readImage([("../image-to-tfrecords/train/parasaurolophus/data%d.png" % i) for i in range(1000)])
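If the real requirement is to match each prediction to its source image rather than to read in a particular order, another option is to keep the TFRecords pipeline, also fetch the image/filename feature that getImage already parses (but does not return), and sort the results afterwards. Plain string sorting puts data1144.png before data138.png, so a numeric key is needed, and because every class folder reuses the names data0.png, data1.png, and so on, the class must be part of the key. numeric_index and sort_results are hypothetical helpers that assume the dataN.png naming shown in the directory tree.

```python
import os
import re


def numeric_index(filename):
    """Hypothetical helper: extract N from a name like 'data1144.png'."""
    match = re.search(r"(\d+)\.png$", os.path.basename(filename))
    if match is None:
        raise ValueError("unexpected filename: %r" % filename)
    return int(match.group(1))


def sort_results(results):
    """Sort (class_text, filename, prediction) tuples into directory order.

    The class comes first in the key because every class folder
    reuses the same data0.png, data1.png, ... names.
    """
    return sorted(results, key=lambda r: (r[0], numeric_index(r[1])))


if __name__ == "__main__":
    names = ["data4918.png", "data1144.png", "data138.png", "data573.png"]
    print(sorted(names))                     # string order
    print(sorted(names, key=numeric_index))  # numeric order
```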