TFRecord returns mismatched paired data

Date: 2019-03-13 14:56:59

Tags: python python-3.x tensorflow tfrecord

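I have a large number of matched image pairs from two domains, A and B. I saved these pairs into several tfrecord files, but when I load the paired data back from the files, the images no longer match.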

Here is my saving code:

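import os

import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.python_io.TFRecordWriter)


def save_tfrecords(paths, desdir):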
    cnt_file_idx = 0
    cnt_data_idx = 0

    filename = os.path.join(desdir, 'data%d.tfrecords' % cnt_file_idx)
    filename_list = [filename]

    writer = tf.python_io.TFRecordWriter(filename)
    for i, path in enumerate(paths):
        data = np.load(path)
        data_shape = np.shape(data)

        width = data_shape[1]  # data is [height, width]; A is the left half, B the right half
        a_image = np.array(data[:, :width // 2])
        b_image = np.array(data[:, width // 2:])

        # up to this point the image pairs are correct

        features = tf.train.Features(
            feature={
                "A": tf.train.Feature(float_list=tf.train.FloatList(value=a_image.reshape(-1))),
                "B": tf.train.Feature(float_list=tf.train.FloatList(value=b_image.reshape(-1))),
                "a_shape": tf.train.Feature(int64_list=tf.train.Int64List(value=np.shape(a_image))),
                "b_shape": tf.train.Feature(int64_list=tf.train.Int64List(value=np.shape(b_image)))
            }
        )
        example = tf.train.Example(features=features)
        serialized = example.SerializeToString()
        writer.write(serialized)

        cnt_data_idx += 1
        if cnt_data_idx == 500:
            writer.close()
            cnt_file_idx += 1
            cnt_data_idx = 0
            filename = os.path.join(desdir, 'data%d.tfrecords' % cnt_file_idx)
            filename_list.append(filename)
            writer = tf.python_io.TFRecordWriter(filename)
    writer.close()
    return filename_list
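Each `tf.train.Example` written above stores both halves of a pair ("A" and "B") plus their shapes, so the writing step by itself cannot separate a pair. The sketch below mirrors the split, flatten, and reshape steps with plain Python lists instead of NumPy and TensorFlow; the helper names (`split_halves`, `to_record`, `from_record`) are hypothetical and only illustrate that the round trip restores both halves from the same record.

```python
# Hypothetical pure-Python stand-in for the save/parse round trip:
# an "image" is a list of rows, split into left (A) and right (B) halves,
# each flattened together with its shape, then reshaped on the way back.

def split_halves(image):
    """Split a [height][width] image into left (A) and right (B) halves."""
    width = len(image[0])
    a = [row[: width // 2] for row in image]
    b = [row[width // 2 :] for row in image]
    return a, b


def to_record(a, b):
    """Mimic one Example: flat values plus shapes for both halves together."""
    return {
        "A": [v for row in a for v in row],
        "B": [v for row in b for v in row],
        "a_shape": (len(a), len(a[0])),
        "b_shape": (len(b), len(b[0])),
    }


def reshape(flat, shape):
    """Rebuild a 2-D list from flat values using the stored shape."""
    height, width = shape
    return [flat[r * width : (r + 1) * width] for r in range(height)]


def from_record(record):
    """Recover both halves from the same record, as the parse function does."""
    return (reshape(record["A"], record["a_shape"]),
            reshape(record["B"], record["b_shape"]))


image = [[1, 2, 3, 4], [5, 6, 7, 8]]
a, b = split_halves(image)
restored_a, restored_b = from_record(to_record(a, b))
print(restored_a, restored_b)  # [[1, 2], [5, 6]] [[3, 4], [7, 8]]
```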

And my loading code:

import glob
import os

import tensorflow as tf  # TensorFlow 1.x API


def load_example(path):  # returns one iterator (not yet initialized)
    def pares_tf(example_proto):
        features = {
            "A": tf.VarLenFeature(dtype=tf.float32),
            "B": tf.VarLenFeature(dtype=tf.float32),

            "a_shape": tf.FixedLenFeature(shape=(2,), dtype=tf.int64),
            "b_shape": tf.FixedLenFeature(shape=(2,), dtype=tf.int64)
        }

        parsed_example = tf.parse_single_example(serialized=example_proto, features=features)

        parsed_example['A'] = tf.sparse_tensor_to_dense(parsed_example['A'])
        parsed_example['B'] = tf.sparse_tensor_to_dense(parsed_example['B'])

        parsed_example['A'] = tf.reshape(parsed_example['A'], parsed_example['a_shape'])
        parsed_example['A'] = tf.expand_dims(parsed_example['A'], -1)
        parsed_example['B'] = tf.reshape(parsed_example['B'], parsed_example['b_shape'])
        parsed_example['B'] = tf.expand_dims(parsed_example['B'], -1)

        return parsed_example

    tf_data_dir = os.path.join(path, 'train', 'pair', 'tf_data')
    tf_filename_list = glob.glob(os.path.join(tf_data_dir, "*.tfrecords"))

    dataset_train = tf.data.TFRecordDataset(tf_filename_list)
    dataset_train = dataset_train.map(pares_tf).repeat().batch(32)

    iterator_train = dataset_train.make_initializable_iterator()

    return iterator_train
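Because `load_example` calls `repeat()` before `batch()` and does not shuffle, batches are cut from an endless, ordered stream of records, so a batch can straddle the end of the data and continue from the first record. The hypothetical generator below is a minimal pure-Python sketch of that ordering, assuming the records are read in a fixed order; it is not TensorFlow code.

```python
from itertools import cycle, islice


def repeat_then_batch(records, batch_size):
    """Yield batches in the order dataset.repeat().batch(batch_size) would,
    with no shuffling: records cycle forever, so a batch can straddle the
    end of the data and continue from the first record."""
    stream = cycle(records)
    while True:
        yield list(islice(stream, batch_size))


batches = repeat_then_batch(list(range(5)), batch_size=3)
print(next(batches))  # [0, 1, 2]
print(next(batches))  # [3, 4, 0]
```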

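Another confusing thing: when loading, if I set the batch size equal to the number of records saved in the tfrecords (500 here), the image pairs appear to be correct. But if I set the batch size to 499, the pairs are mismatched by a fixed distance of 1, i.e. the i-th image from domain A is paired with the (i+1)-th image from domain B; with a batch size of 498 the distance is 2 (the i-th image in A is paired with the (i+2)-th in B), and so on.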
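One possible source of a constant offset of 500 minus the batch size can be worked out under two assumptions the question does not confirm: the dataset holds exactly 500 records read in order with `repeat().batch(batch_size)`, and A and B are fetched in two separate `sess.run` calls, so each call advances the iterator and B comes from the batch after A's. With record `(k * batch_size + j) % 500` at position `j` of batch `k`, the same position in the next batch is `batch_size` records further along, which is a distance of 0 for 500, 1 for 499 and 2 for 498. The helpers below (`record_index`, `pair_distances`) are hypothetical and only check that arithmetic.

```python
def record_index(batch, position, batch_size, n_records):
    """Index of the record at `position` in batch number `batch` when
    n_records are read in order with repeat().batch(batch_size)."""
    return (batch * batch_size + position) % n_records


def pair_distances(batch_size, n_records=500, batch=0):
    """Assumed scenario: A comes from batch `batch`, B from batch `batch + 1`.
    Returns the set of distances between the A and B record indices at each
    position (a single value means the offset is constant)."""
    distances = set()
    for j in range(batch_size):
        a = record_index(batch, j, batch_size, n_records)
        b = record_index(batch + 1, j, batch_size, n_records)
        diff = (b - a) % n_records
        distances.add(min(diff, n_records - diff))
    return distances


print(pair_distances(500))  # {0}
print(pair_distances(499))  # {1}
print(pair_distances(498))  # {2}
```

If that assumption holds, fetching both tensors from a single `get_next()` result in one call, e.g. `batch = iterator_train.get_next()` followed by `a, b = sess.run([batch['A'], batch['B']])`, would take A and B from the same batch.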

I'm confused about why this happens. Can anyone help me solve this?
