tfrecord掩盖未正确完成?

时间:2018-10-01 09:32:31

标签: tensorflow-datasets

我正在尝试将随机的numpy数组转换为tf.records。但它似乎没有正确执行。标签转换很好,但是图像转换不会返回原始图像。

最后,它会打印False,True,而它应该是True,True。我想知道为什么会这样?这是张量流还是我错过了什么?

将tensorflow导入为tf     将numpy导入为np

def wrap_bytes(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def wrap_int64(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def convert(images, labels, out_path):

    num_images = len(labels)
    with tf.python_io.TFRecordWriter(out_path) as writer:
        for i in range(num_images):

            label_ = labels[i]
            # the same problem persists whether or not we flatten
            image = images[i].flatten()
            image_bytes = image.tostring()
            features = \
                {
                    'image': wrap_bytes(tf.compat.as_bytes(image_bytes)),
                    'label': wrap_int64(label_)
                }

            feature = tf.train.Features(feature=features)
            example = tf.train.Example(features=feature)
            serialized = example.SerializeToString()
            writer.write(serialized)


def parse(serialized):
    features = \
        {
            'image': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64),
        }

    parsed_example = \
        tf.parse_single_example(
            serialized=serialized,
            features=features)

    image_raw = parsed_example['image']
    label_raw = parsed_example['label']
    image_ = tf.decode_raw(image_raw, tf.int32)
    image_reshaped = tf.reshape(image_, (5, 5))

    return image_reshaped, label_raw


def input_fn(filenames, batch_size):

    dataset = tf.data.TFRecordDataset(filenames=filenames)
    dataset = dataset.map(parse)
    # dataset = dataset.repeat(1)
    dataset = dataset.batch(batch_size)
    iterator = dataset.make_one_shot_iterator()
    batch_images_tf, batch_labels_tf = iterator.get_next()

    return batch_images_tf, batch_labels_tf


n = 10
num_classes = 15
batch_size = 2
out_path = 'bug.tfrecords'
labels = np.random.randint(0, num_classes, n)

image_shape = (5, 5)
images_ = np.int32(np.random.randint(0, 255, 5*5*n).reshape(n, 5, 5))

convert(images_, labels, out_path)

batch_images_tf, batch_labels_tf = input_fn(out_path, batch_size)

sess = tf.Session()
batch_labels_np = sess.run(batch_labels_tf)
batch_images_np = sess.run(batch_images_tf)

# checking whether the converted data is the same as the original
print(np.array_equal(batch_images_np, images_[0:batch_size]))
print(np.array_equal(batch_labels_np, labels[0:batch_size]))

0 个答案:

没有答案