我有一个基于.tfrecord文件的tensorflow数据集。如何将数据集拆分为测试和训练数据集?例如。 70%的训练和30%的测试?
编辑:
我的Tensorflow版本:1.8 我已经检查过,可能的重复项中没有提到“ split_v”函数。另外,我正在使用tfrecord文件。
答案 0 :(得分:8)
您可以使用Dataset.take()
和Dataset.skip()
:
train_size = int(0.7 * DATASET_SIZE)
val_size = int(0.15 * DATASET_SIZE)
test_size = int(0.15 * DATASET_SIZE)
full_dataset = tf.data.TFRecordDataset(FLAGS.input_file)
full_dataset = full_dataset.shuffle()
train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)
val_dataset = test_dataset.skip(val_size)
test_dataset = test_dataset.take(test_size)
为了更笼统,我举了一个使用70/15/15火车/ val /测试划分的示例,但是如果您不需要测试或val集,则只需忽略最后两行即可。
Take:
从该数据集中创建一个最多包含count个元素的数据集。
Skip:
创建一个数据集,该数据集从该数据集中跳过计数元素。
您可能还想研究Dataset.shard()
:
创建一个仅包含此数据集1 / num_shards的数据集。
答案 1 :(得分:2)
这个问题类似于this one和this one,但恐怕我们还没有令人满意的答案。
使用take()
和skip()
需要知道数据集的大小。如果我不知道或不想知道该怎么办?
使用shard()
仅给出1 / num_shards
数据集。如果我想要其余的东西怎么办?
我尝试在下面提供一种更好的解决方案,该解决方案仅在 TensorFlow 2 上进行了测试。假设您已经有一个 shuffled 数据集,则可以使用filter()
将其分为两个部分:
import tensorflow as tf
all = tf.data.Dataset.from_tensor_slices(list(range(1, 21))) \
.shuffle(10, reshuffle_each_iteration=False)
test_dataset = all.enumerate() \
.filter(lambda x,y: x % 4 == 0) \
.map(lambda x,y: y)
train_dataset = all.enumerate() \
.filter(lambda x,y: x % 4 != 0) \
.map(lambda x,y: y)
for i in test_dataset:
print(i)
print()
for i in train_dataset:
print(i)
参数reshuffle_each_iteration=False
很重要。它可以确保原始数据集不会被随机洗一次。否则,两个结果集可能会有些重叠。
使用enumerate()
添加索引。
使用filter(lambda x,y: x % 4 == 0)
抽取4个样本中的1个。同样,x % 4 != 0
抽取4个样本中的3个。
使用map(lambda x,y: y)
去除索引并恢复原始样本。
此示例实现了75/25的拆分。
x % 5 == 0
和x % 5 != 0
进行80/20分割。
如果您确实希望以70/30的比例进行拆分,则x % 10 < 3
和x % 10 >= 3
应该这样做。
更新:
从TensorFlow 2.0.0开始,由于AutoGraph's limitations,以上代码可能会导致一些警告。要消除这些警告,请分别声明所有lambda函数:
def is_test(x, y):
return x % 4 == 0
def is_train(x, y):
return not is_test(x, y)
recover = lambda x,y: y
test_dataset = all.enumerate() \
.filter(is_test) \
.map(recover)
train_dataset = all.enumerate() \
.filter(is_train) \
.map(recover)
这对我的机器没有任何警告。将is_train()
设为not is_test()
绝对是一个好习惯。