在Tensorflow数据集API中拆分数据集问题

时间:2018-12-30 15:07:59

标签: python tensorflow tensorflow-datasets

我正在使用tf.contrib.data.make_csv_dataset读取csv文件以形成数据集,然后使用命令take()形成仅具有一个元素的另一个数据集,但仍返回所有元素。

这是怎么了?我带来了以下代码:

import tensorflow as tf
import os
tf.enable_eager_execution()

# Constants

column_names = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'species']
class_names = ['Iris setosa', 'Iris versicolor', 'Iris virginica']
batch_size   = 1
feature_names = column_names[:-1]
label_name = column_names[-1]

# to reorient data strucute
def pack_features_vector(features, labels):
  """Pack the features into a single array."""
  features = tf.stack(list(features.values()), axis=1)
  return features, labels

# Download the file
train_dataset_url = "http://download.tensorflow.org/data/iris_training.csv"
train_dataset_fp = tf.keras.utils.get_file(fname=os.path.basename(train_dataset_url),
                                       origin=train_dataset_url)

# form the dataset
train_dataset = tf.contrib.data.make_csv_dataset(
train_dataset_fp,
batch_size, 
column_names=column_names,
label_name=label_name,
num_epochs=1)

# perform the mapping
train_dataset = train_dataset.map(pack_features_vector)

# construct a databse with one element 
train_dataset= train_dataset.take(1)

# inspect elements
for step in range(10):
    features, labels = next(iter(train_dataset))
    print(list(features))

1 个答案:

答案 0 :(得分:0)

基于this的答案,我们可以使用Dataset.take()Dataset.skip()拆分数据集:

train_size = int(0.7 * DATASET_SIZE)

train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)

如何修复代码?

使用一个迭代器代替在循环中多次创建迭代器:

# inspect elements
for feature, label in train_dataset:
    print(feature)

您的代码中发生什么导致这种行为?

1)内置的python iter函数从对象获取迭代器,或者对象本身必须提供自己的迭代器。因此,当您致电iter(train_dataset)时,等同于致电Dataset.make_one_shot_iterator()

2)默认情况下,在tf.contrib.data.make_csv_dataset()中,shuffle参数为True(shuffle=True)。因此,每次调用iter(train_dataset)时,它都会创建一个包含不同数据的新Iterator。

3)最后,在通过for step in range(10)进行循环时,类似于您创建10个大小为1的不同迭代器,每个迭代器都有自己的数据,因为它们被混洗了。

建议:如果要避免此类事情,请在循环外初始化(创建)迭代器:

train_dataset = train_dataset.take(1)
iterator = train_dataset.make_one_shot_iterator()
# inspect elements
for step in range(10):
    features, labels = next(iterator)
    print(list(features))
    # throws exception because size of iterator is 1