在Tensorflow 1.2.0中读取CSV文件

时间:2017-06-25 03:16:04

标签: python csv tensorflow

我正在尝试批量阅读heart.csv文件数据。根据{{​​3}}网站的文档,我有以下代码逐行阅读

import tensorflow as tf
filename_queue = tf.train.string_input_producer(["heart.csv"])
reader = tf.TextLineReader(skip_header_lines=1)
_, csv_row = reader.read(filename_queue)

record_defaults = [[0], [0.0], [0.0], [0.0], [""], [0], [0.0], [0.0], [0], [0]]
sbp, tobacco, ldl, adiposity, famhist, typea, obesity, alcohol, age, chd = tf.decode_csv(csv_row, record_defaults=record_defaults)
features = [sbp, tobacco, ldl, adiposity, famhist, typea, obesity, alcohol, age]

nof_examples = 10
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    while nof_examples > 0:
        nof_examples -= 1
        try:
            data_features, data_chd = sess.run([features, chd])
#             data_features[4] = 1 if data_features[4] == 'Present' else 0
            print(data_features, data_chd)
        except tf.errors.OutOfRangeError:
            coord.request_stop()
            coord.join(threads)
            break
    coord.request_stop()
    coord.join(threads)

输出:

([160, 12.0, 5.73, 23.110001, 'Present', 49, 25.299999, 97.199997, 52], 1)
([144, 0.0099999998, 4.4099998, 28.610001, 'Absent', 55, 28.870001, 2.0599999, 63], 1)
([118, 0.079999998, 3.48, 32.279999, 'Present', 52, 29.139999, 3.8099999, 46], 0)
([170, 7.5, 6.4099998, 38.029999, 'Present', 51, 31.99, 24.26, 58], 1)
([134, 13.6, 3.5, 27.780001, 'Present', 60, 25.99, 57.34, 49], 1)
([132, 6.1999998, 6.4699998, 36.209999, 'Present', 62, 30.77, 14.14, 45], 0)
([142, 4.0500002, 3.3800001, 16.200001, 'Absent', 59, 20.809999, 2.6199999, 38], 0)
([114, 4.0799999, 4.5900002, 14.6, 'Present', 62, 23.110001, 6.7199998, 58], 1)
([114, 0.0, 3.8299999, 19.4, 'Present', 49, 24.860001, 2.49, 29], 0)
([132, 0.0, 5.8000002, 30.959999, 'Present', 69, 30.110001, 0.0, 53], 1)

但是当我尝试按张量流文档中显示的批量阅读时,我得到了

TypeError: Cannot convert a list containing a tensor of dtype <dtype:
float32'> to <dtype: 'int32'> (Tensor is: <tf.Tensor 'DecodeCSV_6:1'
shape=() dtype=float32>)

批处理代码

import tensorflow as tf
batch_size = 1
def read_my_file_format(filename_queue):
    reader = tf.TextLineReader(skip_header_lines=1)
    _, csv_row = reader.read(filename_queue)
    record_defaults = [[0], [0.0], [0.0], [0.0], [""], [0], [0.0], [0.0], [0], [0]]
    sbp, tobacco, ldl, adiposity, famhist, typea, obesity, alcohol, age, chd = tf.decode_csv(csv_row, record_defaults=record_defaults)
    feature = [sbp, tobacco, ldl, adiposity, famhist, typea, obesity, alcohol, age]
    label = [chd]
    return feature, label

def input_pipeline(filenames, batch_size, num_epochs=None):
    filename_queue = tf.train.string_input_producer(filenames, 
                                                    num_epochs=num_epochs, 
                                                    shuffle=True)
    feature, label = read_my_file_format(filename_queue)
    min_after_dequeue = 10000
    capacity = min_after_dequeue + 3 * batch_size
    feature_batch, label_batch = tf.train.shuffle_batch([feature, label], 
                                                        batch_size=batch_size, 
                                                        capacity=capacity,
                                                        min_after_dequeue=min_after_dequeue)
    return feature_batch, label_batch

features, labels = input_pipeline(['heart.csv'], batch_size)

with tf.Session() as sess:
    tf.global_variables_initializer().run()

    # start populating filename queue
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    try:
        while not coord.should_stop():
            feature_batch, label_batch = sess.run([features, labels])
            print(feature_batch)
    except tf.errors.OutOfRangeError:
        print('Done training, epoch reached')
    finally:
        coord.request_stop()
    coord.join(threads) 

使用tensorflow读取CSV文件似乎有点麻烦,但我确信它在作为分布式系统的库中具有重要性。我发现它令人困惑,花了60多分钟阅读并掌握了读取源管道如何为csv文件工作。可能文档应该更好,需要更多的视觉效果。

1 个答案:

答案 0 :(得分:1)

我查看了代码,看来tf.train.shuffle_batch中的一个内部函数要求行中的所有张量都具有相同的dtype(从第一个元素推断出来,在你身上案例a tf.int32)。你可以用字符串解码它们,然后在正确的类型中转换它们。不太方便。

但是我建议您在使用TensorFlow 1.2.0时建议使用新的DataSet API,这是处理数据时的新方法(参见例如this answer)。

根据引用的答案,以下是使用新API的示例:

def read_row(csv_row):
    record_defaults = [[0], [0.0], [0.0], [0.0], [""], [0], [0.0], [0.0], [0], [0]]
    row = tf.decode_csv(csv_row, record_defaults=record_defaults)
    return row[:-1], row[-1]

def input_pipeline(filenames, batch_size):
    # Define a `tf.contrib.data.Dataset` for iterating over one epoch of the data.
    dataset = (tf.contrib.data.TextLineDataset(filenames)
               .skip(1)
               .map(lambda line: read_row(line))
               .shuffle(buffer_size=10)  # Equivalent to min_after_dequeue=10.
               .batch(batch_size))

    # Return an *initializable* iterator over the dataset, which will allow us to
    # re-initialize it at the beginning of each epoch.
    return dataset.make_initializable_iterator()

iterator = input_pipeline(['heart.csv'], batch_size)
features, labels = iterator.get_next()


nof_examples = 10
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    sess.run(iterator.initializer)
    while nof_examples > 0:
        nof_examples -= 1
        try:
            data_features, data_labels = sess.run([features, labels])
            print(data_features)
        except tf.errors.OutOfRangeError:
            pass