使用tf.data读取CSV文件非常慢,请改用tfrecords?

时间:2018-05-17 05:33:53

标签: python tensorflow tensorflow-datasets

我有很多CSV文件,每条记录包含~6000列。第一列是标签,其余列应视为特征向量。我是Tensorflow的新手,我无法弄清楚如何将数据读入具有所需格式的Tensorflow Dataset。我目前正在运行以下代码:

DEFAULTS = []
n_features = 6170
for i in range(n_features+1):
  DEFAULTS.append([0.0])

def parse_csv(line):
    # line = line.replace('"', '')
    columns = tf.decode_csv(line, record_defaults=DEFAULTS)  # take a line at a time
    features = {'label': columns[-1], 'x': tf.stack(columns[:-1])}  # create a dictionary out of the features
    labels = features.pop('label')  # define the label

    return features, labels


def train_input_fn(data_file=sample_csv_file, batch_size=128):
    """Generate an input function for the Estimator."""
    # Extract lines from input files using the Dataset API.
    dataset = tf.data.TextLineDataset(data_file)
    dataset = dataset.map(parse_csv)
    dataset = dataset.shuffle(10000).repeat().batch(batch_size)
    return dataset.make_one_shot_iterator().get_next()

每个CSV文件都有~10K记录。我已尝试在train_input_fn上作为labels = train_input_fn()[1].eval(session=sess)进行示例评估。这有128个标签,但它需要 2分钟

我是否使用了一些冗余操作,或者有更好的方法吗?

PS:我在Spark Dataframe中有原始数据。因此,如果可以让事情变得更快,我也可以使用TFRecords。

1 个答案:

答案 0 :(得分:2)

你做得对。但更快的方法是使用TFRecords,如以下步骤所示:

1.使用tf.python_io.TFRecordWriter

要读取csv文件并将其写为tfrecord文件,如下所示:Tensorflow create a tfrecords file from csv

2.从tfrecord中读取:

    def _parse_function(proto):
       f = {
           "features": tf.FixedLenSequenceFeature([], tf.float32, default_value=0.0, allow_missing=True),
           "label": tf.FixedLenSequenceFeature([], tf.float32, default_value=0.0, allow_missing=True)
           }
           parsed_features = tf.parse_single_example(proto, f)
           features = parsed_features["features"]
           label = parsed_features["label"]
           return features, label


    dataset = tf.data.TFRecordDataset(['csv.tfrecords'])
    dataset = dataset.map(_parse_function)
    dataset = dataset = dataset.shuffle(10000).repeat().batch(128)
    iterator = dataset.make_one_shot_iterator()
    features, label = iterator.get_next()

我在随机生成的csv上运行了两个案例(csv vs tfrecords)。 csv直接读取的10个批次(每个128个样本)的总时间约为204s,而tfrecord的总时间约为0.22s