I have a Keras model that needs to be fed data from multiple sources via multiple ImageDataGenerators (the model still has only one input). I wrote a function that does this (in practice I use 5-6 generators):
def multiple_generator(batch_size):
    genX1 = train_datagen.flow_from_directory('./directory1',
                                              target_size=(img_height, img_width),
                                              batch_size=batch_size // 2,
                                              class_mode='categorical')
    genX2 = train_datagen.flow_from_directory('./directory2',
                                              target_size=(img_height, img_width),
                                              batch_size=batch_size // 2,
                                              class_mode='categorical')
    while True:
        X1i = genX1.next()
        X2i = genX2.next()
        yield (np.concatenate([X1i[0], X2i[0]], axis=0),
               np.concatenate([X1i[1], X2i[1]], axis=0))
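To make the intended behavior concrete, here is a minimal NumPy-only sketch of the same combining logic, with a dummy in-memory generator standing in for `flow_from_directory` (the image shape and class count are made up for illustration):

```python
import numpy as np

def fake_flow(batch_size, n_classes=10, img_shape=(32, 32, 3)):
    """Stand-in for flow_from_directory: yields (images, one-hot labels)."""
    while True:
        x = np.random.rand(batch_size, *img_shape).astype("float32")
        y = np.eye(n_classes)[np.random.randint(0, n_classes, batch_size)]
        yield x, y

def multiple_generator(batch_size):
    # Each source contributes half of every batch.
    genX1 = fake_flow(batch_size // 2)
    genX2 = fake_flow(batch_size // 2)
    while True:
        X1i = next(genX1)
        X2i = next(genX2)
        # Stack the two half-batches into one full batch of images and labels.
        yield (np.concatenate([X1i[0], X2i[0]], axis=0),
               np.concatenate([X1i[1], X2i[1]], axis=0))

x, y = next(multiple_generator(64))
print(x.shape, y.shape)  # (64, 32, 32, 3) (64, 10)
```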
But when training starts, it takes much longer than with a single generator. For example, with a single generator each epoch takes only about 120 seconds regardless of batch_size, whereas with multiple_generator an epoch takes 5 minutes at batch_size = 64 and 12 minutes at batch_size = 128.
I suspected that iterating over multiple generators sequentially was slowing down training, and I thought the following function could run them in parallel:
def multiple_generator(batch_size):
    pool = Pool(processes=2)
    genX1 = pool.apply(train_datagen.flow_from_directory('./directory1',
                                                         target_size=(img_height, img_width),
                                                         batch_size=batch_size // 2,
                                                         class_mode='categorical'))
    genX2 = pool.apply(train_datagen.flow_from_directory('./directory2',
                                                         target_size=(img_height, img_width),
                                                         batch_size=batch_size // 2,
                                                         class_mode='categorical'))
    while True:
        X1i = genX1.next()
        X2i = genX2.next()
        yield (np.concatenate([X1i[0], X2i[0]], axis=0),
               np.concatenate([X1i[1], X2i[1]], axis=0))
But it raises this error:
Traceback (most recent call last):
File "/home/cngc3/anaconda3/envs/tensorflow/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2961, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-26-2d1d4ddfacf1>", line 4, in <module>
class_mode='categorical'))
File "/home/cngc3/anaconda3/envs/tensorflow/lib/python3.6/multiprocessing/pool.py", line 259, in apply
return self.apply_async(func, args, kwds).get()
File "/home/cngc3/anaconda3/envs/tensorflow/lib/python3.6/multiprocessing/pool.py", line 644, in get
raise self._value
File "/home/cngc3/anaconda3/envs/tensorflow/lib/python3.6/multiprocessing/pool.py", line 424, in _handle_tasks
put(task)
File "/home/cngc3/anaconda3/envs/tensorflow/lib/python3.6/multiprocessing/connection.py", line 206, in send
self._send_bytes(_ForkingPickler.dumps(obj))
File "/home/cngc3/anaconda3/envs/tensorflow/lib/python3.6/multiprocessing/reduction.py", line 51, in dumps
cls(buf, protocol).dump(obj)
TypeError: can't pickle _thread.lock objects
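As far as I understand, the error comes from trying to send the generator objects to the pool: multiprocessing has to pickle whatever it ships to a worker, and the Keras iterators internally hold thread locks, which cannot be pickled. A minimal reproduction of the same failure, independent of Keras:

```python
import pickle
import threading

# Any object holding a thread lock (as the Keras DirectoryIterator does
# internally) fails to pickle, which is what Pool.apply needs to do.
try:
    pickle.dumps(threading.Lock())
except TypeError as e:
    print(type(e).__name__, e)
```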
I don't have much experience with multiprocessing — do you have a solution for this? Other strategies for speeding up the generators are always welcome. Thanks a lot.