I can't get joblib to run with my function, which takes a numpy array, a list of trained Keras models, and a list of strings as arguments.
I have tried building the arguments as a namedtuple and even as a class with immutable attributes. Any ideas?
import collections

Params = collections.namedtuple('Params', ['inputs', 'y_list', 'trained_models'])
p = Params(inputs, y_list, trained_models)
or
class Params:
    def __init__(self, inputs, y_list, trained_models):
        super(Params, self).__setattr__("inputs", inputs)
        super(Params, self).__setattr__("y_list", y_list)
        super(Params, self).__setattr__("trained_models", trained_models)
The function I'd like to run in parallel:
def predict(params):
    inputs = params.inputs
    y_list = params.y_list
    trained_models = params.trained_models
    # process and vectorize inputs
    X = new_X(inputs)
    X_vect = vect.transform(X)
    predictions = dict()
    for y in y_list:
        y_field = trained_models[y].predict(X_vect)
        # evaluate model
        if y_field[0] > 0.05:
            return None, None
        predictions[y] = y_field[0]
    return X, predictions
Calling the function in parallel:
r = Parallel(n_jobs=4, verbose=5)(
    delayed(predict)(p)
    for c in range(100))
Error:
TypeError                                 Traceback (most recent call last)
<timed exec> in <module>()
~/.conda/envs/mlgpu/lib/python3.6/site-packages/joblib/parallel.py in
__call__(self, iterable)
787 # consumption.
788 self._iterating = False
--> 789 self.retrieve()
790 # Make sure that we get a last message telling us we are done
791 elapsed_time = time.time() - self._start_time
~/.conda/envs/mlgpu/lib/python3.6/site-packages/joblib/parallel.py in retrieve(self)
697 try:
698 if getattr(self._backend, 'supports_timeout', False):
--> 699 self._output.extend(job.get(timeout=self.timeout))
700 else:
701 self._output.extend(job.get())
~/.conda/envs/mlgpu/lib/python3.6/multiprocessing/pool.py in get(self, timeout)
642 return self._value
643 else:
--> 644 raise self._value
645
646 def _set(self, i, obj):
~/.conda/envs/mlgpu/lib/python3.6/multiprocessing/pool.py in
_handle_tasks(taskqueue, put, outqueue, pool, cache)
422 break
423 try:
--> 424 put(task)
425 except Exception as e:
426 job, idx = task[:2]
~/.conda/envs/mlgpu/lib/python3.6/site-packages/joblib/pool.py in send(obj)
369 def send(obj):
370 buffer = BytesIO()
--> 371 CustomizablePickler(buffer, self._reducers).dump(obj)
372 self._writer.send_bytes(buffer.getvalue())
373 self._send = send
TypeError: can't pickle _thread.lock objects
Answer 0 (score: 0)
You should create your own class, because you don't know whether collections.namedtuple
has parts that cannot be pickled.
A few months ago I ran into a similar problem: I had added a lambda function to a class that I was passing as an argument. Since lambda functions cannot be serialized by the pickle
package, I got an error just like this one.
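As a minimal sketch of this suggestion (the helper find_unpicklable_fields is made up for illustration and is not from the original post), you can use a plain class for the parameters and try pickle.dumps on each attribute separately to see which one joblib actually fails to serialize:

import pickle

class Params:
    """Plain parameter container, as suggested above."""
    def __init__(self, inputs, y_list, trained_models):
        self.inputs = inputs
        self.y_list = y_list
        self.trained_models = trained_models

def find_unpicklable_fields(params):
    # Try to pickle each attribute on its own; the attributes that raise here
    # are the ones joblib's worker processes will also fail to serialize.
    bad = []
    for name, value in vars(params).items():
        try:
            pickle.dumps(value)
        except Exception as exc:  # e.g. TypeError: can't pickle _thread.lock objects
            bad.append((name, repr(exc)))
    return bad

# usage:
# print(find_unpicklable_fields(Params(inputs, y_list, trained_models)))

If the check points at trained_models, the Keras models themselves are probably the unpicklable part (they typically hold thread locks), and a common workaround is to load or rebuild the models inside the worker function instead of passing them in as arguments.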