如何拆分Tensorflow数据集?

时间:2018-07-01 17:00:31

标签: tensorflow tensorflow-datasets

我有一个基于.tfrecord文件的tensorflow数据集。如何将数据集拆分为测试和训练数据集?例如。 70%的训练和30%的测试?

编辑:

我的Tensorflow版本:1.8 我已经检查过,可能的重复项中没有提到“ split_v”函数。另外,我正在使用tfrecord文件。

2 个答案:

答案 0 :(得分:8)

您可以使用Dataset.take()Dataset.skip()

train_size = int(0.7 * DATASET_SIZE)
val_size = int(0.15 * DATASET_SIZE)
test_size = int(0.15 * DATASET_SIZE)

full_dataset = tf.data.TFRecordDataset(FLAGS.input_file)
full_dataset = full_dataset.shuffle()
train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)
val_dataset = test_dataset.skip(val_size)
test_dataset = test_dataset.take(test_size)

为了更笼统,我举了一个使用70/15/15火车/ val /测试划分的示例,但是如果您不需要测试或val集,则只需忽略最后两行即可。

Take

  

从该数据集中创建一个最多包含count个元素的数据集。

Skip

  

创建一个数据集,该数据集从该数据集中跳过计数元素。

您可能还想研究Dataset.shard()

  

创建一个仅包含此数据集1 / num_shards的数据集。

答案 1 :(得分:2)

这个问题类似于this onethis one,但恐怕我们还没有令人满意的答案。

  • 使用take()skip()需要知道数据集的大小。如果我不知道或不想知道该怎么办?

  • 使用shard()仅给出1 / num_shards数据集。如果我想要其余的东西怎么办?

我尝试在下面提供一种更好的解决方案,该解决方案仅在 TensorFlow 2 上进行了测试。假设您已经有一个 shuffled 数据集,则可以使用filter()将其分为两个部分:

import tensorflow as tf

all = tf.data.Dataset.from_tensor_slices(list(range(1, 21))) \
        .shuffle(10, reshuffle_each_iteration=False)

test_dataset = all.enumerate() \
                    .filter(lambda x,y: x % 4 == 0) \
                    .map(lambda x,y: y)

train_dataset = all.enumerate() \
                    .filter(lambda x,y: x % 4 != 0) \
                    .map(lambda x,y: y)

for i in test_dataset:
    print(i)

print()

for i in train_dataset:
    print(i)

参数reshuffle_each_iteration=False很重要。它可以确保原始数据集不会被随机洗一次。否则,两个结果集可能会有些重叠。

使用enumerate()添加索引。

使用filter(lambda x,y: x % 4 == 0)抽取4个样本中的1个。同样,x % 4 != 0抽取4个样本中的3个。

使用map(lambda x,y: y)去除索引并恢复原始样本。

此示例实现了75/25的拆分。

x % 5 == 0x % 5 != 0进行80/20分割。

如果您确实希望以70/30的比例进行拆分,则x % 10 < 3x % 10 >= 3应该这样做。

更新:

从TensorFlow 2.0.0开始,由于AutoGraph's limitations,以上代码可能会导致一些警告。要消除这些警告,请分别声明所有lambda函数:

def is_test(x, y):
    return x % 4 == 0

def is_train(x, y):
    return not is_test(x, y)

recover = lambda x,y: y

test_dataset = all.enumerate() \
                    .filter(is_test) \
                    .map(recover)

train_dataset = all.enumerate() \
                    .filter(is_train) \
                    .map(recover)

这对我的机器没有任何警告。将is_train()设为not is_test()绝对是一个好习惯。