TensorFlow - 在读取和写入TFRecords文件时设置图像的形状?

时间:2016-03-13 19:36:31

标签: python numpy tensorflow

在尝试使用TFRecords格式时,我遇到了设置图像数据形状的问题。我已经查看了how-to for reading data,并从MNIST示例中获取了converting the image data to a TFRecordsreading the data from the TFRecords的代码。但是,此示例代码最初希望图像以所有像素数据位于一个长向量中的格式使用。

我一直在尝试更改此代码以使用仍处于原始图像形状的NumPy数组。因此,在下面的代码中,images是一个形状为[number_of_images, height, width, channels]的NumPy数组。我不确定我的问题是关于我如何将数据写入TFRecords,或者我是如何将其读回来的。但是,当我尝试设置解码图像的形状时,我得到错误ValueError: Shapes (?,) and (464, 624, 3) must have the same rank(注意:464 x 624 x 3是图像尺寸)。关于我可能做错什么的任何建议?

相关代码(从示例代码略有改动)

def convert_to_tfrecord(images, labels, name, data_directory):
    number_of_examples = labels.shape[0]
    rows = images.shape[1]  # images is the 4D ndarray with the images in their original shape.
    cols = images.shape[2]
    depth = images.shape[3]
    ...
    for index in range(number_of_examples):
        image_raw = images[index].tostring()
        example = tf.train.Example(features=tf.train.Features(feature={
            'height': _int64_feature(rows),
            'width': _int64_feature(cols),
            'channels': _int64_feature(depth),
            'image': _bytes_feature(image_raw),
            ...
        }))
        writer.write(example.SerializeToString())

...

def read_and_decode(filename_queue):
    ...
    features = tf.parse_single_example(
        serialized_example,
        features={
            'image_raw': tf.FixedLenFeature([], tf.string),
            ...
        })
    ...
    image = tf.decode_raw(features['image_raw'], tf.uint8)
    image.set_shape([464, 624, 3])  # This is where the error occurs.
    image = tf.cast(image, tf.float32) * (1. / 255) - 0.5
    ...

1 个答案:

答案 0 :(得分:4)

请注意set_shape不会改变底层缓冲区的形状,它只是设置一个可以在此张量上看到的可能形状集的图形级注释。

要更改实际形状,您需要使用tf.reshape