我从2个域A和B中获得了大量匹配的配对图像。我将这些图像保存到一些tfrecord文件中,但是当我从文件中加载配对数据时,它们不再匹配。
这是我的保存代码:
def save_tfrecords(paths, desdir, records_per_file=500):
    """Serialize paired A/B images into sharded TFRecord files.

    Every input array is split in half along axis 1: the left half becomes
    image "A" and the right half image "B". Both halves, together with their
    shapes, are stored in the *same* ``tf.train.Example``, so a pair cannot be
    separated on disk.

    Args:
        paths: Iterable of file paths readable by ``np.load``.
        desdir: Directory that receives the ``data<N>.tfrecords`` shards.
        records_per_file: Number of examples written to a shard before a new
            shard is started. Defaults to 500, the previously hard-coded
            value.

    Returns:
        List of shard filenames in creation order. A new shard is opened as
        soon as the previous one is full, so the last shard is empty when the
        number of inputs is zero or an exact multiple of ``records_per_file``.

    Raises:
        ValueError: If ``records_per_file`` is smaller than 1.
    """
    if records_per_file < 1:
        raise ValueError("records_per_file must be >= 1")

    shard_idx = 0
    in_shard = 0
    filename = os.path.join(desdir, 'data%d.tfrecords' % shard_idx)
    filename_list = [filename]
    writer = tf.python_io.TFRecordWriter(filename)
    try:
        for path in paths:
            data = np.load(path)
            # Axis 1 is treated as the width axis.
            # NOTE(review): the original comment described the layout as
            # [height, width, channels]; confirm the actual rank of the
            # stored arrays, since the saved shape features carry one entry
            # per axis.
            width = np.shape(data)[1]
            half = width // 2
            a_image = np.array(data[:, :half])
            b_image = np.array(data[:, half:])

            # Pixels are flattened to a float list; the shape is saved next
            # to them so the reader can restore the original layout.
            features = tf.train.Features(
                feature={
                    "A": tf.train.Feature(
                        float_list=tf.train.FloatList(value=a_image.reshape(-1))),
                    "B": tf.train.Feature(
                        float_list=tf.train.FloatList(value=b_image.reshape(-1))),
                    "a_shape": tf.train.Feature(
                        int64_list=tf.train.Int64List(value=np.shape(a_image))),
                    "b_shape": tf.train.Feature(
                        int64_list=tf.train.Int64List(value=np.shape(b_image))),
                }
            )
            example = tf.train.Example(features=features)
            writer.write(example.SerializeToString())

            in_shard += 1
            if in_shard == records_per_file:
                # Current shard is full: close it and roll over to the next.
                writer.close()
                writer = None
                shard_idx += 1
                in_shard = 0
                filename = os.path.join(desdir, 'data%d.tfrecords' % shard_idx)
                filename_list.append(filename)
                writer = tf.python_io.TFRecordWriter(filename)
    finally:
        # Release the open shard even when loading or serialization fails.
        if writer is not None:
            writer.close()
    return filename_list
这是我的加载代码:
def load_example(path):
    """Build an initializable iterator over the paired A/B TFRecord data.

    Reads every ``*.tfrecords`` file under ``<path>/train/pair/tf_data``,
    parses each record into its "A" and "B" images, repeats the dataset
    indefinitely and groups it into batches of 32.

    Args:
        path: Root directory containing ``train/pair/tf_data``.

    Returns:
        A single initializable iterator (not yet initialized). Each element
        produced by its ``get_next()`` is a dict with the keys "A", "B",
        "a_shape" and "b_shape"; "A" and "B" are reshaped to their stored
        2-element shape with one trailing axis appended.

    NOTE(review): every evaluation of ``get_next()`` advances the iterator by
    one batch. "A" and "B" must be taken from the same ``get_next()`` result
    and fetched in one ``sess.run`` call; fetching them in separate runs pairs
    A from one batch with B from the following batch, which yields an offset
    that depends on the batch size. Confirm how the caller consumes this
    iterator.
    """
    def _parse_pair(example_proto):
        # The flattened pixel lists have no fixed length, so they are parsed
        # as sparse features and densified afterwards.
        spec = {
            "A": tf.VarLenFeature(dtype=tf.float32),
            "B": tf.VarLenFeature(dtype=tf.float32),
            "a_shape": tf.FixedLenFeature(shape=(2,), dtype=tf.int64),
            "b_shape": tf.FixedLenFeature(shape=(2,), dtype=tf.int64),
        }
        parsed = tf.parse_single_example(serialized=example_proto, features=spec)
        for image_key, shape_key in (("A", "a_shape"), ("B", "b_shape")):
            dense = tf.sparse_tensor_to_dense(parsed[image_key])
            image = tf.reshape(dense, parsed[shape_key])
            # Append a trailing axis of size 1 to the restored 2-D image.
            parsed[image_key] = tf.expand_dims(image, -1)
        return parsed

    tf_data_dir = os.path.join(path, 'train', 'pair', 'tf_data')
    # glob returns matches in arbitrary order; sort so the record order is
    # reproducible across runs and machines.
    tf_filename_list = sorted(glob.glob(os.path.join(tf_data_dir, "*.tfrecords")))
    dataset_train = tf.data.TFRecordDataset(tf_filename_list)
    dataset_train = dataset_train.map(_parse_pair).repeat().batch(32)
    iterator_train = dataset_train.make_initializable_iterator()
    return iterator_train
另一个令人困惑的现象是:当我在加载数据时把批量大小设置为与每个tfrecords文件中保存的样本数相同(这里是500)时,图像对似乎是正确的;但如果把批量大小设置为499,就会出现固定为1的错位,即域A中的第i张图像与域B中的第(i + 1)张图像配对;如果批量大小为498,错位距离则为2(A中的第i张与B中的第i + 2张配对),以此类推。
我很困惑为什么会发生这种情况。谁能帮我解决这个问题?