Tensorflow读取多csv需要太多时间

时间:2018-11-26 12:34:25

标签: tensorflow tensorflow-datasets

我计划从多个csv文件读取特征数据。
每个csv文件有150列,批处理大小为256。

读取1000次迭代所需的时间大约为12s。
我觉得这样做的时间成本不应该那么多,这里有人可以提出建议吗?

def _parse_csv_row(*vals):
    features = tf.convert_to_tensor(vals[0:f_size * 5])
    class_label = tf.cast(vals[f_size * 5] + tf.convert_to_tensor(1.0, tf.float64), tf.int64)
    return features, class_label


def get_batch_data(name):
    root_path="g:\\market\\2018-11-12\\feature_{}\\".format(name)
    file_queue = list(map(lambda x: "{}{}".format(root_path, x), fnmatch.filter(os.listdir("g:\\market\\2018-11-12\\feature_{}\\".format(name)), "*.sz_result.csv")))
    record_defaults = [tf.float32] * f_size * 5 + [tf.float64]
    selected_cols = reduce(lambda x, y: x + y, [list(range(1 + x * 29, 1 + x * 29 + 9)) for x in range(0, 5)]) + [146]
    dataset = tf.contrib.data.CsvDataset(
        file_queue,
        record_defaults,
        buffer_size=1024 * 1024 * 10,
        header=True,
        na_value='0.0',
        select_cols=selected_cols)
    dataset = dataset.apply(tf.contrib.data.map_and_batch(
        map_func=_parse_csv_row, batch_size=train_config.BATCH_SIZE))
    dataset = dataset.prefetch(256 * 1024)
    dataset = dataset.repeat()

    dataset = dataset.shuffle(buffer_size=32)
    iterator = dataset.make_one_shot_iterator()
    feature_batch, label_batch = iterator.get_next()
    return feature_batch, label_batch

config = tf.ConfigProto()
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
    sess.run(tf.global_variables_initializer())
    a, b = get_batch_data("train")
    start_time = time.time()
    for x in range(1000):
        v = sess.run([a,b])
    print(time.time() - start_time)

1 个答案:

答案 0 :(得分:1)

dataset = dataset.prefetch(256 * 1024)行。它是在map_and_batch操作之后编写的。这意味着您正在预取256 * 1024批次。因此,当您的程序尝试加载第一条记录时,它实际上首先加载了256 * 1024 * 256条记录。可能您的意图是仅预取1024个批次。在现实生活中,仅预取一条记录就足够了。

我将ds.prefetch(1)行作为数据集的最后一个操作。参见Summary of Best Practices

还为读取CSV文件(buffer_size参数)分配了很大的缓冲区。如果您打算缓存整个csv文件,则可以使用ds.cache()操作。如果没有参数,它将内容缓存在内存中。将其放在ds.repeat()操作之前。