使用keras.utils.Sequence多处理和数据库 - 何时连接?

时间:2018-04-17 13:41:40

标签: mongodb tensorflow keras python-multiprocessing

我正在使用带有Tensorflow后端的Keras训练神经网络。数据集不适合RAM,因此,我将其存储在Mongo数据库中,并使用keras.utils.Sequence的子类检索批次。

如果我使用model.fit_generator()运行use_multiprocessing=False,一切正常。

当我打开多处理时,我会在产生工作人员或连接到数据库时遇到错误。

如果我在__init__中创建了一个连接,我就会遇到一个例外,其中的文字说明了酸洗锁定对象中的错误。对不起,我不记得了。但培训甚至没有开始。

如果我在__get_item__中创建连接,则训练开始并运行一些时期,然后我收到错误[WinError 10048] Only one usage of each socket address (protocol/network address/port) is normally permitted

根据the pyMongo manuals,它不是fork安全的,每个子进程必须创建自己的数据库连接。我使用不使用分叉的Windows,而是生成进程,但是,差异并不重要,恕我直言。

这解释了为什么无法在__init__中进行连接。

以下是docs的另一个引用:

  

为每个流程创建一次此客户端,并将其重用于所有操作。为每个请求创建一个新客户端是一个常见的错误,这是非常低效的。

这解释了__get_item__中的错误。

然而,目前尚不清楚,我的班级如何理解Keras已经创造了新的过程。

这是我的Sequence实现的最后一个变体的伪代码(每个请求上的新连接):

import pymongo
import numpy as np
from keras.utils import Sequence
from keras.utils.np_utils import to_categorical

class MongoSequence(Sequence):
    def __init__(self, train_set, batch_size, server=None, database="database", collection="full_set"):
        self._train_set = train_set
        self._server = server
        self._db = database
        self.collection = collection
        self._batch_size = batch_size

        query = {}  # train_set query
        self._object_ids = [ smp["_id"] for uid in train_set for smp in self._connect().find(query, {'_id': True})]

    def _connect(self):
        client = pymongo.MongoClient(self._server)
        db = self._client[self._db]
        return _db[self._collection]

    def __len__(self):
        return int(np.ceil(len(self._object_ids) / float(self._batch_size)))

    def __getitem__(self, item):
        oids = self._object_ids[item * self._batch_size: (item+1) * self._batch_size]
        X = np.empty((len(oids), IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS), dtype=np.float32)
        y = np.empty((len(oids), 2), dtype=np.float32)
        for i, oid in enumerate(oids):
            smp = self._connect().find({'_id': oid}).next()
            X[i, :, :, :] = pickle.loads(smp['frame']).astype(np.float32)
            y[i] = to_categorical(not smp['result'], 2)
        return X, y

也就是说,在对象构造上,我根据标准检索所有相关的ObjectIDs形成火车组。在调用__getitem__时,将从数据库中检索实际对象。他们的ObjectIDs是从列表切片中确定的。

这段调用model.fit_generator(generator=MongoSequence(train_ids, batch_size=10), ... )的代码会生成几个python进程,每个进程根据日志消息初始化Tensorflow后端,并开始训练。

但最后,该函数抛出异常,称为connect,位于pymongo内部。

不幸的是,我还没有存储调用堆栈。上面描述了错误,我再说一遍:[WinError 10048] Only one usage of each socket address (protocol/network address/port) is normally permitted

我的假设是此代码与服务器建立的连接太多,因此在__getitem__中连接是错误的。

构造函数中的连接也是错误的,因为它是在主进程中执行的,而Mongo文档直接对它进行反对。

Sequence班中还有一种方法on_epoch_end。但是,我需要在纪元开始时连接,而不是结束。

引自Keras docs:

  

如果您想在世纪之间修改数据集,可以实施on_epoch_end

那么,有什么建议吗?文档在这里不是很具体。

2 个答案:

答案 0 :(得分:1)

看起来我找到了解决方案。解决方案是 - 跟踪进程ID并在其更改时重新连接

class MongoSequence(Sequence):
    def __init__(self, batch_size, train_set, query=None, server=None, database="database", collection="full_set"):
        self._server = server
        self._db = database
        self._collection_name = collection
        self._batch_size = batch_size
        self._query = query
        self._collection = self._connect()

        self._object_ids = [ smp["_id"] for uid in train_set for smp in self._collection.find(self._query, {'_id': True})]

        self._pid = os.getpid()
        del self._collection   #  to be sure, that we've disconnected
        self._collection = None

    def _connect(self):
        client = pymongo.MongoClient(self._server)
        db = self._client[self._db]
        return db[self._collection_name]

    def __len__(self):
        return int(np.ceil(len(self._object_ids) / float(self._batch_size)))

    def __getitem__(self, item):
        if self._collection is None or self._pid != os.getpid():
            self._collection = self._connect()
            self._pid = os.getpid()

        oids = self._object_ids[item * self._batch_size: (item+1) * self._batch_size]
        X = np.empty((len(oids), IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS), dtype=np.float32)
        y = np.empty((len(oids), 2), dtype=np.float32)
        for i, oid in enumerate(oids):
            smp = self._connect().find({'_id': oid}).next()
            X[i, :, :, :] = pickle.loads(smp['frame']).astype(np.float32)
            y[i] = to_categorical(not smp['result'], 2)
        return X, y

答案 1 :(得分:0)

on_epoch_end()中创建您的连接,并通过' init ()'方法明确调用on_epoch_end()。这使得on_epoch_end()在实践中起作用,好像ti“在时代开始”。 (每个纪元的结束,是下一个纪元的开始。第一个纪元在它之前没有一个纪元,因此在初始化中显式调用。)