Tensorflow Dataset API如何订购list_files?

时间:2017-10-28 13:09:58

标签: tensorflow tensorflow-datasets

我正在使用数据集API list_files来获取source目录和target目录中的文件列表,例如:

source_path = '/tmp/data/source/*.ext1'
target_path = '/tmp/data/target/*.ext2'
source_dataset = tf.data.Dataset.list_files(source_path)
target_dataset = tf.data.Dataset.list_files(data_path)
dataset = tf.data.Dataset.zip((source_dataset, target_dataset))

源和目标目录内容具有相同的顺序文件名,但具有不同的扩展名(例如,源0001.ext1< - >目标0001.ext2)。

但由于list_files无论如何都没有排序,因此压缩数据集包含源和目标之间的不匹配。

如何在新数据集API中解决此问题?

2 个答案:

答案 0 :(得分:0)

此方法的默认行为是按不确定的随机混排顺序返回文件名。传递种子或shuffle = False以获得确定性的结果。

source_dataset = tf.data.Dataset.list_files(source_path, shuffle=False)

val = 5
source_dataset = tf.data.Dataset.list_files(source_path, seed = val)
target_dataset = tf.data.Dataset.list_files(data_path, seed = val)

答案 1 :(得分:0)

我遇到了同样的问题,我通过先对文件路径进行排序来解决它。

我的文件在 OP 的情况下命名为:

input image       -> corresponding output
data/mband/01.tif -> data/gt_mband/01.tif
data/mband/02.tif -> data/gt_mband/02.tif

代码如下:

from pathlib import Path
import tensorflow as tf

DATA_PATH = Path("data")

# Sort the PATHS
img_paths = sorted(map(str, (DATA_PATH / 'mband').glob('*.tif')))
mask_paths = sorted(map(str, (DATA_PATH / 'gt_mband').glob('*.tif')))

# These are tensors of PATHS
# Paths are strings, so order will be preserved
img_paths = tf.data.Dataset.from_tensor_slices(img_paths)
mask_paths = tf.data.Dataset.from_tensor_slices(mask_paths)

# Load the actual images
def parse_image(image_path: 'some_tensor'):
    # Load the image somehow...
    return image_as_tensor

imgs = img_paths.map(parse_image)
masks = mask_paths.map(parse_mask)