python-为Keras LSTM读取多个CSV

时间:2019-02-26 06:21:06

标签: python csv keras lstm

我正在尝试使用Keras实现LSTM网络,但是在输入时遇到了问题。 我的数据集采用多个CSV文件的形式(所有文件的尺寸均为68x250,每个条目均包含2个值)。各个类别之间大约有200个CSV文件。 Preview of one of the CSVs

如何将这些多个CSV作为输入?

2 个答案:

答案 0 :(得分:1)

定义一个实现以下接口定义的类: https://keras.io/utils/#sequence

并使用model.fit_generator方法。

答案 1 :(得分:1)

最近我做了类似的事情,因为佩德罗说你应该使用fit_generator并编写自定义生成器。

以下是生成器的示例:

def generator(files):
    print('start generator')
    while 1:        
        print('loop generator')
        for file in files:
            try:                 
                df = pd.read_csv(file)
                batches = int(np.ceil(len(df)/batch_size))      
                for i in range(0, batches):                                   
                    yield pad_batch(df[i*batch_size:min(len(df), i*batch_size+batch_size)])

            except EOFError:
                print("error" + file) 

将文件名列表传递给生成器的位置,然后遍历文件并分批返回内容。在我的情况下,load_data是一个函数,它读取熊猫中的csvs并进行一些预处理。 pad_batch对LSTM进行填充。

用法:

model.fit_generator(
      generator=generator(trainingFiles),   
      steps_per_epoch=steps,
      epochs=num_epochs,
      validation_data=[x_test, y_test],
      verbose=1)