在tensorlfow数据集中,如何混合2个数据集,分别从原始数据中获取75%的数据集,从扩充数据中获取25%的数据集?
d = tf.data.Dataset.list_files("raw_data/")\
.flat_map(tf.data.TFRecordDataset)
ad = tf.data.Dataset.list_files("augmented_data/")\
.flat_map(tf.data.TFRecordDataset)
答案 0 :(得分:1)
问题是您不能在数据集对象上使用len()
,因此有时很难知道确切的示例数,直到您迭代整个历元。但是您可以使用take
和skip
方法对此进行近似。
train_dataset = dataset.take(number_examples_for_train)
test_dataset = dataset.skip(number_examples_for_train)
这些方法是彼此直接替代的。 https://www.tensorflow.org/api_docs/python/tf/data/Dataset#take