生成器从多个文件读取作为keras的输入的多处理

时间:2018-07-31 09:27:33

标签: python keras multiprocessing generator

我正在用keras训练神经网络,并希望通过多处理来加快我的预处理/数据增强。原则上,在workers=N中使用use_multiprocessing=Truefit_generator似乎很简单,但是在我的情况下,避免从并行生成器中获取相似数据是很棘手的。

我的数据保存在几个文件中,每个文件都有几百万条记录(直到您到达文件末尾时,才知道总数)。对于每个文件,生成器逐条记录,将记录处理为网络的正确输入/输出格式,并增加一些数据。没有唯一的ID,尽管我想我可以即时创建一个。

我想知道并行创建多个生成器,每个生成器分别处理一个单独的文件列表是否最简单。我实际上并没有批量使用所有数据,因此,如果一个生成器在其文件列表的开始之前先于其他生成器重新启动,则实际上并不重要。如果在生成器中我可以访问工人编号(从1到N),则很容易做到。

1 个答案:

答案 0 :(得分:0)

我不确定如何实施您的建议。一个更高级的解决方案是实例化tf.data.TextLineDataset,它可以处理多个文本文件。为了以此训练Keras模型,您必须将iterator的输出绑定到模型的Input张量。遵循以下原则:

import tensorflow as tf 
# Parsing, augmentation etc
def __parse_record(record):
    ...
    return parsed_record

# Construct a TextLineDataset
ds = tf.data.TextLineDataset(filenames).map(_parse_record)
ds.shuffle().batch(batch_size) # Shuffle and batch

# Turn into an iterator
iterator = tf.data.Iterator.from_structure(ds.output_types, ds.output_shapes)
ds_init = iterator.make_initializer(ds)
# The iterator will yield inputs and labels
x,y = iterator.get_next()

# Tie output of iterator into Input of keras model via the tensor argument
model_input = Input(tensor=x)
# ... model definition

# Upon compiling the model specify target tensors
model.compile(loss, optimizer, target_tensors=[y])

# Now you can use model.fit() instead of fit_generator()
with K.get_session() as sess:
    sess.run(ds_init)
    model.fit(epochs, steps_per_epoch)

这应该训练得很快,但是它带来了一些缺点。根据相关的Keras example

  

输入张量也有重要的缺点。在   特别是输入张量固定在模型构造上   因为尚不支持重新布线网络。   因此,更改数据输入源意味着   必须保存模型权重并重建模型   从头开始连接新的输入数据。   当前无法作为培训进行验证   进步,必须在训练完成后再进行。