我正在尝试使用Keras实现LSTM网络,但是在输入时遇到了问题。 我的数据集采用多个CSV文件的形式(所有文件的尺寸均为68x250,每个条目均包含2个值)。各个类别之间大约有200个CSV文件。 Preview of one of the CSVs
如何将这些多个CSV作为输入?
答案 0 :(得分:1)
定义一个实现以下接口定义的类: https://keras.io/utils/#sequence
并使用model.fit_generator方法。
答案 1 :(得分:1)
最近我做了类似的事情,因为佩德罗说你应该使用fit_generator并编写自定义生成器。
以下是生成器的示例:
def generator(files):
print('start generator')
while 1:
print('loop generator')
for file in files:
try:
df = pd.read_csv(file)
batches = int(np.ceil(len(df)/batch_size))
for i in range(0, batches):
yield pad_batch(df[i*batch_size:min(len(df), i*batch_size+batch_size)])
except EOFError:
print("error" + file)
将文件名列表传递给生成器的位置,然后遍历文件并分批返回内容。在我的情况下,load_data
是一个函数,它读取熊猫中的csvs并进行一些预处理。 pad_batch
对LSTM进行填充。
用法:
model.fit_generator(
generator=generator(trainingFiles),
steps_per_epoch=steps,
epochs=num_epochs,
validation_data=[x_test, y_test],
verbose=1)