有谁知道如何将Tensorflow中数据集API(tf.data.Dataset)创建的数据集拆分为Test and Train?
答案 0 :(得分:18)
假设您有all_dataset
变量tf.data.Dataset
类型:
test_dataset = all_dataset.take(1000)
train_dataset = all_dataset.skip(1000)
测试数据集现在有前1000个元素,其余的用于训练。
答案 1 :(得分:7)
此处大多数答案使用take()
和skip()
,这需要事先了解数据集的大小。这并非总是可能,或者很难/难以确定。
实际上,您可以做的是对数据集进行切片,以使每N条记录中有1条成为验证记录。
要做到这一点,让我们从0-9的简单数据集开始:
dataset = tf.data.Dataset.range(10)
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
现在在我们的示例中,我们将对其进行切片,以使我们得到3/1训练/验证拆分。意思是3条记录将进行训练,然后1条记录进行验证,然后重复。
split = 3
dataset_train = dataset.window(split, split + 1).flat_map(lambda ds: ds)
# [0, 1, 2, 4, 5, 6, 8, 9]
dataset_validation = dataset.skip(split).window(1, split + 1).flat_map(lambda ds: ds)
# [3, 7]
因此,第一个dataset.window(split, split + 1)
说要获取split
个元素的数量(3),然后前进split + 1
个元素,然后重复。 + 1
有效地跳过了我们将在验证数据集中使用的1元素。
flat_map(lambda ds: ds)
是因为window()
批量返回结果,这是我们不想要的。因此,我们将其压平。
然后我们首先获取验证数据,skip(split)
会跳过在第一个训练窗口中获取的元素的前split
个数字(3),因此我们在第四个元素上开始迭代。 window(1, split + 1)
然后抓取1个元素,前进split + 1
(4),然后重复。
关于嵌套数据集的说明:
上面的示例适用于简单的数据集,但是如果嵌套数据集,flat_map()
将产生错误。为了解决这个问题,您可以将flat_map()
换成可以处理简单和嵌套数据集的更复杂的版本:
.flat_map(lambda *ds: ds[0] if len(ds) == 1 else tf.data.Dataset.zip(ds))
答案 2 :(得分:4)
您可以使用Dataset.take()
和Dataset.skip()
:
train_size = int(0.7 * DATASET_SIZE)
val_size = int(0.15 * DATASET_SIZE)
test_size = int(0.15 * DATASET_SIZE)
full_dataset = tf.data.TFRecordDataset(FLAGS.input_file)
full_dataset = full_dataset.shuffle()
train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)
val_dataset = test_dataset.skip(val_size)
test_dataset = test_dataset.take(test_size)
为了更笼统,我举了一个使用70/15/15火车/ val /测试划分的示例,但是如果您不需要测试或val集,则只需忽略最后两行即可。
Take:
从该数据集中创建一个最多包含count个元素的数据集。
Skip:
创建一个数据集,该数据集从该数据集中跳过计数元素。
您可能还想研究Dataset.shard()
:
创建一个仅包含此数据集1 / num_shards的数据集。
免责声明我在回答this one之后偶然发现了这个问题,所以我以为我会传播爱心
答案 3 :(得分:3)
@ted的答案将引起某些重叠。试试这个。
train_ds_size = int(0.64 * full_ds_size)
valid_ds_size = int(0.16 * full_ds_size)
train_ds = full_ds.take(train_ds_size)
remaining = full_ds.skip(train_ds_size)
valid_ds = remaining.take(valid_ds_size)
test_ds = remaining.skip(valid_ds_size)
使用下面的代码进行测试。
tf.enable_eager_execution()
dataset = tf.data.Dataset.range(100)
train_size = 20
valid_size = 30
test_size = 50
train = dataset.take(train_size)
remaining = dataset.skip(train_size)
valid = remaining.take(valid_size)
test = remaining.skip(valid_size)
for i in train:
print(i)
for i in valid:
print(i)
for i in test:
print(i)
答案 4 :(得分:0)
现在Tensorflow不包含任何工具。
您可以使用sklearn.model_selection.train_test_split
生成训练/评估/测试数据集,然后分别创建tf.data.Dataset
。
答案 5 :(得分:0)
您可以使用shard
:
dataset = dataset.shuffle() # optional
trainset = dataset.shard(2, 0)
testset = dataset.shard(2, 1)
请参阅: https://www.tensorflow.org/api_docs/python/tf/data/Dataset#shard
答案 6 :(得分:0)
在已知数据集大小的情况下:
from typing import Tuple
import tensorflow as tf
def split_dataset(dataset: tf.data.Dataset,
dataset_size: int,
train_ratio: float,
validation_ratio: float) -> Tuple[tf.data.Dataset, tf.data.Dataset, tf.data.Dataset]:
assert (train_ratio + validation_ratio) < 1
train_count = int(dataset_size * train_ratio)
validation_count = int(dataset_size * validation_ratio)
test_count = dataset_size - (train_count + validation_count)
dataset = dataset.shuffle(dataset_size)
train_dataset = dataset.take(train_count)
validation_dataset = dataset.skip(train_count).take(validation_count)
test_dataset = dataset.skip(validation_count + train_count).take(test_count)
return train_dataset, validation_dataset, test_dataset
示例:
size_of_ds = 1001
train_ratio = 0.6
val_ratio = 0.2
ds = tf.data.Dataset.from_tensor_slices(list(range(size_of_ds)))
train_ds, val_ds, test_ds = split_dataset(ds, size_of_ds, train_ratio, val_ratio)
答案 7 :(得分:-1)
@ apatsekin,@ ted最近我的声誉不超过50,所以我只需要在这里回答答案,我想直接使用.take方法获取或不获取测试数据集是否合理。如果数据集在每个纪元都经过了改组,那么它将得到不同的TRAIN / TEST划分,因为在训练过程中,我们需要测试集永远不会出现在训练集中。所以这应该是一个问题
或者我们在shuffle中添加一个参数:
train_size = int(0.7 * DATASET_SIZE)
val_size = int(0.15 * DATASET_SIZE)
test_size = int(0.15 * DATASET_SIZE)
full_dataset = tf.data.TFRecordDataset(FLAGS.input_file)
full_dataset = full_dataset.shuffle( reshuffle_each_iteration = False )
train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)
val_dataset = test_dataset.skip(val_size)
test_dataset = test_dataset.take(test_size)
答案 8 :(得分:-2)
无法发表评论,但以上答案有重叠且不正确。将BUFFER_SIZE设置为DATASET_SIZE以获得完美的随机播放。尝试使用其他大小的val / test大小进行验证。答案应该是:
DATASET_SIZE = tf.data.experimental.cardinality(full_dataset).numpy()
train_size = int(0.7 * DATASET_SIZE)
val_size = int(0.15 * DATASET_SIZE)
test_size = int(0.15 * DATASET_SIZE)
full_dataset = full_dataset.shuffle(BUFFER_SIZE)
train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)
val_dataset = test_dataset.take(val_size)
test_dataset = test_dataset.skip(val_size)