我正在尝试用Keras处理一个大型训练数据集。
我将 `model.fit_generator` 与一个从SQLite数据库文件中读取数据的自定义生成器一起使用。
结果收到一条错误消息,提示不能在两个不同的线程中使用同一个SQLite对象:
ProgrammingError: SQLite objects created in a thread can only be used in that
same thread. The object was created in thread id 140736714019776 and this is
thread id 123145449209856
我对HDF5文件尝试了同样的做法,结果遇到了分段错误(segmentation fault)。我现在认为这同样与 `fit_generator` 的多线程特性有关(参见相关的错误报告)。
对于无法一次性装入内存的数据集,必须从文件中分批读取数据;那么,正确使用这类生成器的方式是什么?
以下是生成器的代码:
class DataGenerator:
    """Serve (X, Y) training batches for Keras from the ``data`` key of an HDF5 file.

    NOTE(review): the HDF5 file is read from whichever thread iterates the
    generator; the multi-threading problems described in the question are
    not addressed by this class.
    """

    def __init__(self, inputfile, batch_size, **kwargs):
        # **kwargs is accepted but unused (kept for call compatibility).
        self.inputfile = inputfile
        self.batch_size = batch_size

    def generate(self, labels, idlist):
        """Yield ``(X, Y)`` batches over ``idlist`` endlessly, pass after pass.

        Args:
            labels: frame left-joined onto each batch on the ``id`` column;
                presumably it carries the ``label`` column -- confirm.
            idlist: sequence of ids to read, in ``batch_size`` chunks.

        Raises:
            ValueError: on the first ``next()`` if ``idlist`` is empty or
                ``batch_size`` < 1; either previously made the endless loop
                spin forever without ever yielding.
        """
        # len() rather than truthiness so numpy arrays / pandas Series work too.
        if len(idlist) == 0:
            raise ValueError('idlist must not be empty')
        if self.batch_size < 1:
            raise ValueError('batch_size must be a positive integer')
        while True:
            for batch in self._read_data_from_hdf(idlist):
                batch = pandas.merge(batch, labels, how='left', on=['id'])
                Y = batch['label']
                X = batch.drop(['id', 'label'], axis=1)
                yield (X, Y)

    def _read_data_from_hdf(self, idlist):
        """Yield one DataFrame per ``batch_size`` slice of ``idlist`` (last may be shorter)."""
        for start in range(0, len(idlist), self.batch_size):
            chunk = idlist[start:start + self.batch_size]
            # Convert numpy scalars to plain Python values: formatting a numpy
            # slice gives "[1 2]" (no commas) and numpy-2 scalars render as
            # "np.int64(1)", both of which break the where-expression.
            ids = [v.item() if hasattr(v, 'item') else v for v in chunk]
            yield pandas.read_hdf(self.inputfile, key='data',
                                  where='id in {}'.format(ids))
# [...]
import math

# Use ceil(n / batch_size) steps per epoch. The generator yields a final
# partial batch on every pass over its id list, so floor division (`//`)
# leaves one batch per pass unconsumed: epoch boundaries drift away from
# pass boundaries, and a partition smaller than one batch gets 0 steps.
# NOTE(review): assumes training_generator / validation_generator iterate
# partitions['train'] / partitions['validation'] -- confirm in the elided code.
model.fit_generator(generator=training_generator,
                    steps_per_epoch=math.ceil(len(partitions['train'])
                                              / config['batch_size']),
                    validation_data=validation_generator,
                    validation_steps=math.ceil(len(partitions['validation'])
                                               / config['batch_size']),
                    epochs=config['epochs'])
完整示例请参见该代码仓库(full example repository)。
感谢您的支持。
干杯,
本
答案 0(得分:1)
我也遇到了同样的问题。我的解决办法是把线程安全装饰器与 SQLAlchemy 引擎结合使用,由该引擎管理对数据库的并发访问(注意:下面的代码还需要 `import threading`):
import functools
import threading

import pandas
from sqlalchemy import create_engine
class threadsafe_iter:
    """Wrap an iterator so that concurrent ``next()`` calls are serialized.

    ``self.lock`` is held while the wrapped iterator advances, so two
    threads never execute the underlying generator at the same time.
    """

    def __init__(self, it):
        self.lock = threading.Lock()
        self.it = it

    def __iter__(self):
        # An iterator is its own iterable.
        return self

    def __next__(self):
        with self.lock:
            item = next(self.it)
        return item
def threadsafe_generator(f):
    """Decorate a generator function so every call returns a ``threadsafe_iter``.

    ``functools.wraps`` copies ``f``'s ``__name__``, ``__doc__`` and
    ``__wrapped__`` onto the wrapper, so decorated methods (e.g. ``generate``)
    keep their identity for introspection and debugging.
    """
    @functools.wraps(f)
    def g(*a, **kw):
        return threadsafe_iter(f(*a, **kw))
    return g
class DataGenerator:
    """Serve (X, Y) training batches for Keras from the ``data`` table of an SQLite file.

    Reads go through a SQLAlchemy engine instead of a raw SQLite connection
    object, and ``generate`` is wrapped with ``threadsafe_generator`` so
    several consumer threads can share one generator without running it
    concurrently.
    """

    def __init__(self, inputfile, batch_size, **kwargs):
        # **kwargs is accepted but unused (kept for call compatibility).
        self.inputfile = inputfile
        self.batch_size = batch_size
        self.sqlengine = create_engine('sqlite:///' + self.inputfile)

    def __del__(self):
        # __init__ may have failed before the engine existed; don't let the
        # finalizer raise AttributeError in that case.
        engine = getattr(self, 'sqlengine', None)
        if engine is not None:
            engine.dispose()

    @threadsafe_generator
    def generate(self, labels, idlist):
        """Yield ``(X, Y)`` batches over ``idlist`` endlessly, pass after pass.

        Args:
            labels: accepted for interface parity but unused; the ``label``
                column is read from the ``data`` table itself.
            idlist: sequence of ids to read, in ``batch_size`` chunks.

        Raises:
            ValueError: on the first ``next()`` if ``idlist`` is empty or
                ``batch_size`` < 1; either previously made the endless loop
                spin forever without ever yielding.
        """
        # len() rather than truthiness so numpy arrays / pandas Series work too.
        if len(idlist) == 0:
            raise ValueError('idlist must not be empty')
        if self.batch_size < 1:
            raise ValueError('batch_size must be a positive integer')
        while True:
            for batch in self._read_data_from_sql(idlist):
                Y = batch['label']
                X = batch.drop(['id', 'label'], axis=1)
                yield (X, Y)

    def _read_data_from_sql(self, idlist):
        """Yield one DataFrame per ``batch_size`` slice of ``idlist`` (last may be shorter)."""
        for start in range(0, len(idlist), self.batch_size):
            chunk = idlist[start:start + self.batch_size]
            # Bind values as parameters instead of formatting tuple(chunk) into
            # the SQL: a one-element tuple renders as "(5,)", which is an SQLite
            # syntax error. numpy scalars are converted because sqlite3 cannot
            # bind them.
            params = tuple(v.item() if hasattr(v, 'item') else v for v in chunk)
            placeholders = ', '.join('?' * len(params))
            query = 'select * from data where id in ({})'.format(placeholders)
            yield pandas.read_sql(query, self.sqlengine, params=params)
# Build keras model and instantiate generators
model.fit_generator(
    generator=training_generator,
    epochs=10,
    steps_per_epoch=train_steps,
    validation_data=validation_generator,
    validation_steps=valid_steps,
    # Several workers pull from the same generator; presumably thread-based
    # (use_multiprocessing left at its default -- confirm), which is why
    # DataGenerator.generate is wrapped with threadsafe_generator.
    workers=4,
)
希望这能有所帮助!