使用Tensorflow中的CSV-Reader和队列输入序列

时间:2016-12-19 09:30:47

标签: python csv tensorflow

参考:https://www.tensorflow.org/how_tos/reading_data/

目标:

我想使用以下形式的前面序列训练LSTM: [t0 t1 t2],[t1 t2 t3],[t2 t3 t4] ......

此外,这些序列也应该改组。 例如[t2 t3 t4],[t0 t1 t2],[t1 t2 t3] ...

我的数据存储在csv文件中,每行代表一个时间步。这些列包含不同的功能和目标值。

问题:

有没有办法在Tensorflow中使用 csv-reader和队列(不是占位符和feed_dict)来提供混洗的连贯序列? 我想不出用tf.TextLineReader()和tf.train.shuffle_batch()来实现它的方法。

我的解决方法做了它应该做的事情,但速度非常慢:

train_filename_queue = tf.train.string_input_producer([path])
rand_ind_q = tf.train.range_input_producer(data_len-seq_len, shuffle=True)

def read_csv(filename_queue, ncols, header_lines):
''' returns a list of tensors with content of csv-file
'''
    # content <- [(data_len,) ... ncols ... (data_len,)]

    whole_reader = tf.WholeFileReader()
    _, content = whole_reader.read(filename_queue)
    content = tf.string_split([content], delimiter='\n').values[header_lines:]  
    record_defaults = ncols*[[0.]]
    content = tf.decode_csv(content, record_defaults, field_delim=',')
    return content

def slice_seq(q, content, seq_len):
''' returns a list of tensors with sequences
'''
    # seq <- [(1,seq_len,) ... ncols ... (1,seq_len,)]

    start_ind = q.dequeue()

    seq = list(map(lambda tensor: tf.slice(tensor, [start_ind], [seq_len]), content))
    seq = list(map(lambda tensor: tf.reshape(tensor, (1,seq_len,)), seq))
    return seq    

1 个答案:

答案 0 :(得分:0)

您要加载的csv文件有多大?如果它很大,那么使用WholeFileReader是有问题的;如果它很小,那么我建议在TensorFlow之外创建序列。

那就是说,如果你想在TensorFlow中这样做,你可以试试

block_size = 10
seq_len = 3
batch_size = 4
queue_capacity = 32
num_threads = 4
csv_path = '/path/to/some.csv'

filename_queue = tf.train.string_input_producer([csv_path])
block = tf.TextLineReader().read_up_to(filename_queue, block_size).values
subsequences = [tf.slice(block, [i], [block_size - seq_len + 1]) for i in range(seq_len)]
batched = tf.train.shuffle_batch(subsequences,
                                 batch_size, 
                                 capacity=queue_capacity,
                                 min_after_dequeue=queue_capacity/2,     
                                 num_threads=num_threads,
                                 enqueue_many=True)
decoded = [tf.decode_csv(b, [[0.]] * ncols, ',') for b in batched]

这里唯一的缺点是你会大致错过子序列的seq_len / batch_size。

希望这有帮助。