Keras fit_generator()
的参数pickle_safe
默认为False
如果是 pickle_safe,训练可以更快地运行,并相应地将标志设置为True
?
根据Kera's docs:
pickle_safe :如果为True,请使用基于流程的线程。请注意,因为此实现依赖于多处理,所以不应将非可选参数传递给生成器,因为它们无法轻松传递给子进程。
我不明白这是什么意思。
如何判断我的参数是pickle_safe
还是 ??
如果相关:
- 我正在传递自定义发电机
- 生成器函数接受参数:X_train,y_train,batch_size,p_keep;
它们的类型为np.array,int,float)
- 我没有使用GPU
- 此外,我正在使用Keras 1.2.1,但我相信这个论点与keras 2中的行为相同
答案 0 :(得分:6)
我不熟悉keras
,但是从文档中可以看出,pickle_safe
只意味着您的生成器生成的元组必须是“可选择的”。
pickle
是一个标准的python模块,用于序列化和反序列化对象。标准multiprocessing
实现使用pickle
机制在不同进程之间共享对象 - 因为这两个进程不共享相同的地址空间,所以它们不能直接看到相同的python对象。因此,要将对象从进程A发送到进程B,它们在A中进行pickle(以特定的已知格式生成一系列字节),然后通过进程间通信机制将pickle格式发送到B,并且在B中打开,在B的地址空间中生成A原始对象的副本。
因此,要发现您的对象是否可以选择,只需在它们上调用pickle.dumps
即可。
>>> import pickle
>>> class MyObject:
... def __init__(self, a, b, c):
... self.a = a
... self.b = b
... self.c = c
...
>>> foo = MyObject(1, 2, 3)
>>> pickle.dumps(foo)
b'\x80\x03c__main__\nMyObject\nq\x00)\x81q\x01}q\x02(X\x01\x00\x00\x00cq\x03K\x03X\x01\x00\x00\x00aq\x04K\x01X\x01\x00\x00\x00bq\x05K\x02ub.'
>>>
dumps
生成一个字节字符串。我们现在可以使用foo
将bar
对象从字节字符串重构为loads
:
>>> foo_pick = pickle.dumps(foo)
>>> bar = pickle.loads(foo_pick)
>>> bar
<__main__.MyObject object at 0x7f5e262ece48>
>>> bar.a, bar.b, bar.c
(1, 2, 3)
如果某些东西不可挑选,你会得到一个例外。例如,lambdas不能被腌制:
>>> class MyOther:
... def __init__(self, a, b, c):
... self.a = a
... self.b = b
... self.c = c
... self.printer = lambda: print(self.a, self.b, self.c)
...
>>> other = MyOther(1, 2, 3)
>>> other_pick = pickle.dumps(other)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
AttributeError: Can't pickle local object 'MyOther.__init__.<locals>.<lambda>'
有关详细信息,请参阅文档: https://docs.python.org/3.5/library/pickle.html?highlight=pickle#what-can-be-pickled-and-unpickled