Tensorflow的decode_csv只读一行

时间:2017-08-10 21:28:56

标签: python tensorflow

如何让decode_csv函数读取CSV中的每一行?

我目前正在尝试将CS​​V文件中的数据加载到GPU上。数据加载到GPU上,除了......实际上只读取了我的640行CSV文件的一行。你认为我哪里出错?

import tensorflow as tf

with tf.device('/gpu:0'):
    filename_queue = tf.train.string_input_producer(['dataset.csv'])
    reader = tf.TextLineReader()
    key, value = reader.read(filename_queue)

    record_defaults = [['']]*121
    all_columns = tf.decode_csv(value, record_defaults=record_defaults)

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        # Start populating the filename queue.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        # Iterate through all the columns
        vals = []
        for x in range(121):
            tmp = all_columns.pop()
            myval = tmp.eval(session=sess)
            vals.append(myval)

        coord.request_stop()
        coord.join(threads)

然后,如果我......

>>> import numpy as np
>>> vals = np.asarray(vals)
>>> vals.shape
(121,)

我的CSV中的640行每个都有121列。 vals中的值看起来很好,除了我实际上并没有读取所有640行。我猜这与它有关:

all_columns = tf.decode_csv(value, record_defaults=record_defaults)

1 个答案:

答案 0 :(得分:0)

NVM。想出来了。

显然,sess.run()pop()在提取行数据方面存在差异。

我的CSV文件中有640行和121列,因此:

record_defaults = [['']]*121

for x in range(640):

请注意,这主要是为了测试而硬编码。解决方案如下:

import tensorflow as tf

with tf.device('/gpu:0'):
filename_queue = tf.train.string_input_producer(['../Datasets/CMU_face_images_dataset.csv'])
    reader = tf.TextLineReader()
    key, value = reader.read(filename_queue)

    record_defaults = [['']]*121
    all_columns = tf.decode_csv(value, record_defaults=record_defaults)

    # TWO NEW LINES
    name = all_columns[0]
    data = all_columns[1:]

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        # Start populating the filename queue.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        vals = []
        names = []
        for x in range(640):

            # THIS IS THE NEW LINE
            _name, _val = sess.run([name, data])

            # OLD LINES
            # tmp = all_columns.pop()
            # myval = tmp.eval(session=sess)
            # vals.append(myval)

            names.append(_name)
            vals.append(_val)

        coord.request_stop()
        coord.join(threads)