我正在使用带有Tensorflow后端的Keras训练神经网络。数据集不适合RAM,因此,我将其存储在Mongo数据库中,并使用keras.utils.Sequence
的子类检索批次。
如果我使用model.fit_generator()
运行use_multiprocessing=False
,一切正常。
当我打开多处理时,我会在产生工作人员或连接到数据库时遇到错误。
如果我在__init__
中创建了一个连接,我就会遇到一个例外,其中的文字说明了酸洗锁定对象中的错误。对不起,我不记得了。但培训甚至没有开始。
如果我在__get_item__
中创建连接,则训练开始并运行一些时期,然后我收到错误[WinError 10048] Only one usage of each socket address (protocol/network address/port) is normally permitted
。
根据the pyMongo manuals,它不是fork安全的,每个子进程必须创建自己的数据库连接。我使用不使用分叉的Windows,而是生成进程,但是,差异并不重要,恕我直言。
这解释了为什么无法在__init__
中进行连接。
以下是docs的另一个引用:
为每个流程创建一次此客户端,并将其重用于所有操作。为每个请求创建一个新客户端是一个常见的错误,这是非常低效的。
这解释了__get_item__
中的错误。
然而,目前尚不清楚,我的班级如何理解Keras已经创造了新的过程。
这是我的Sequence实现的最后一个变体的伪代码(每个请求上的新连接):
import pymongo
import numpy as np
from keras.utils import Sequence
from keras.utils.np_utils import to_categorical
class MongoSequence(Sequence):
def __init__(self, train_set, batch_size, server=None, database="database", collection="full_set"):
self._train_set = train_set
self._server = server
self._db = database
self.collection = collection
self._batch_size = batch_size
query = {} # train_set query
self._object_ids = [ smp["_id"] for uid in train_set for smp in self._connect().find(query, {'_id': True})]
def _connect(self):
client = pymongo.MongoClient(self._server)
db = self._client[self._db]
return _db[self._collection]
def __len__(self):
return int(np.ceil(len(self._object_ids) / float(self._batch_size)))
def __getitem__(self, item):
oids = self._object_ids[item * self._batch_size: (item+1) * self._batch_size]
X = np.empty((len(oids), IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS), dtype=np.float32)
y = np.empty((len(oids), 2), dtype=np.float32)
for i, oid in enumerate(oids):
smp = self._connect().find({'_id': oid}).next()
X[i, :, :, :] = pickle.loads(smp['frame']).astype(np.float32)
y[i] = to_categorical(not smp['result'], 2)
return X, y
也就是说,在对象构造上,我根据标准检索所有相关的ObjectIDs
形成火车组。在调用__getitem__
时,将从数据库中检索实际对象。他们的ObjectIDs
是从列表切片中确定的。
这段调用model.fit_generator(generator=MongoSequence(train_ids, batch_size=10), ... )
的代码会生成几个python进程,每个进程根据日志消息初始化Tensorflow后端,并开始训练。
但最后,该函数抛出异常,称为connect
,位于pymongo
内部。
不幸的是,我还没有存储调用堆栈。上面描述了错误,我再说一遍:[WinError 10048] Only one usage of each socket address (protocol/network address/port) is normally permitted
。
我的假设是此代码与服务器建立的连接太多,因此在__getitem__
中连接是错误的。
构造函数中的连接也是错误的,因为它是在主进程中执行的,而Mongo文档直接对它进行反对。
Sequence
班中还有一种方法on_epoch_end
。但是,我需要在纪元开始时连接,而不是结束。
引自Keras docs:
如果您想在世纪之间修改数据集,可以实施on_epoch_end
那么,有什么建议吗?文档在这里不是很具体。
答案 0 :(得分:1)
看起来我找到了解决方案。解决方案是 - 跟踪进程ID并在其更改时重新连接
class MongoSequence(Sequence):
def __init__(self, batch_size, train_set, query=None, server=None, database="database", collection="full_set"):
self._server = server
self._db = database
self._collection_name = collection
self._batch_size = batch_size
self._query = query
self._collection = self._connect()
self._object_ids = [ smp["_id"] for uid in train_set for smp in self._collection.find(self._query, {'_id': True})]
self._pid = os.getpid()
del self._collection # to be sure, that we've disconnected
self._collection = None
def _connect(self):
client = pymongo.MongoClient(self._server)
db = self._client[self._db]
return db[self._collection_name]
def __len__(self):
return int(np.ceil(len(self._object_ids) / float(self._batch_size)))
def __getitem__(self, item):
if self._collection is None or self._pid != os.getpid():
self._collection = self._connect()
self._pid = os.getpid()
oids = self._object_ids[item * self._batch_size: (item+1) * self._batch_size]
X = np.empty((len(oids), IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS), dtype=np.float32)
y = np.empty((len(oids), 2), dtype=np.float32)
for i, oid in enumerate(oids):
smp = self._connect().find({'_id': oid}).next()
X[i, :, :, :] = pickle.loads(smp['frame']).astype(np.float32)
y[i] = to_categorical(not smp['result'], 2)
return X, y
答案 1 :(得分:0)
在on_epoch_end()
中创建您的连接,并通过' init ()'方法明确调用on_epoch_end()
。这使得on_epoch_end()
在实践中起作用,好像ti“在时代开始”。 (每个纪元的结束,是下一个纪元的开始。第一个纪元在它之前没有一个纪元,因此在初始化中显式调用。)