我正在用以下代码做机器学习（我对 Python 和 PyTorch 还很陌生）。简单来说，我认为问题在于：由于某种原因，多进程并行处理没有正常进行。
代码来自：https://raw.githubusercontent.com/harryhan618/LaneNet/master/demo_test.py
该代码的作用是在图像上绘制车道线标记。
import os

import cv2
import numpy as np
import torch

# Workaround for Intel OpenMP's "duplicate runtime" abort; presumably needed on
# this Windows setup — verify. Kept after torch and before the lane_files
# imports, matching the original ordering.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

from lane_files.model import LaneNet
from lane_files.utils.transforms import *
from lane_files.utils.postprocess import embedding_post_process
# One colour per lane label, assigned in ascending label order; reused
# cyclically when the post-processing yields more labels than rows here.
LANE_COLORS = np.array([[255, 125, 0], [0, 255, 0], [0, 0, 255], [0, 255, 255]], dtype='uint8')


def colorize_lanes(lane_seg_img, palette=None):
    """Paint every distinct label of a 2-D label map with a palette colour.

    Labels are visited in ascending order (as returned by ``np.unique``) and
    label number ``i`` gets ``palette[i % len(palette)]``, so more labels than
    colours no longer raises ``IndexError``.

    Args:
        lane_seg_img: 2-D integer label map.
        palette: (K, 3) array-like of uint8 colours; defaults to ``LANE_COLORS``.

    Returns:
        uint8 array of shape ``lane_seg_img.shape + (3,)``. It has the same
        spatial size as the label map, so it can be blended directly with an
        image resized to that size.

    Raises:
        ValueError: if ``palette`` is empty.
    """
    palette = LANE_COLORS if palette is None else np.asarray(palette, dtype=np.uint8)
    if len(palette) == 0:
        raise ValueError("palette must contain at least one colour")
    labels = np.asarray(lane_seg_img)
    seg_img = np.zeros(labels.shape + (3,), dtype=np.uint8)
    for i, lane_idx in enumerate(np.unique(labels)):
        seg_img[labels == lane_idx] = palette[i % len(palette)]
    return seg_img


if __name__ == '__main__':
    net = LaneNet(pretrained=False, embed_dim=7, delta_v=.5, delta_d=3.)
    transform = Compose(Resize((800, 288)), ToTensor(),
                        Normalize(mean=(0.3598, 0.3653, 0.3662), std=(0.2573, 0.2663, 0.2756)))

    img_path = 'data/train_images/frame0.png'
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread reports a missing/unreadable file by returning None
        # instead of raising, which would otherwise fail later in cvtColor.
        raise FileNotFoundError(f"could not read image: {img_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # RGB for net model input
    x = transform(img)[0]
    x.unsqueeze_(0)  # add the batch dimension

    save_dict = torch.load('lane_files/experiments/exp0/exp0_best.pth', map_location='cpu')
    net.load_state_dict(save_dict['net'])
    net.eval()
    # Pure inference: skip building the autograd graph.
    with torch.no_grad():
        output = net(x)

    embedding = output['embedding'].detach().cpu().numpy()
    embedding = np.transpose(embedding[0], (1, 2, 0))  # first batch item, channels last
    bin_seg_prob = output['binary_seg'].detach().cpu().numpy()
    bin_seg_pred = np.argmax(bin_seg_prob, axis=1)[0]  # per-pixel class of the first batch item

    # NOTE(review): the reported BrokenProcessPool ("A task has failed to
    # un-serialize") is raised inside this call. Per the traceback,
    # lane_files/utils/postprocess.py runs sklearn's MeanShift.fit, which
    # dispatches work to joblib's loky process pool. That module is not shown
    # here; a likely fix is to build MeanShift there with n_jobs=1, or to run it
    # under joblib's 'threading' backend, so no worker processes are spawned.
    # TODO confirm.
    lane_seg_img = embedding_post_process(embedding, bin_seg_pred, 0.5)

    # Size both the overlay and the background from the label map. The old
    # zeros_like(original image) only worked when the input frame was
    # already 800x288.
    seg_img = colorize_lanes(lane_seg_img)
    height, width = seg_img.shape[:2]
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    img = cv2.resize(img, (width, height))
    img = cv2.addWeighted(src1=seg_img, alpha=0.8, src2=img, beta=1., gamma=0.)
    cv2.imshow("", img)
    cv2.waitKey(5000)
    cv2.destroyAllWindows()
预期结果：显示带有车道标记的图像。
实际结果：抛出如下异常：
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "C:/Users/sarim/PycharmProjects/thesis/pytorch_learning.py", line 36, in <module>
lane_seg_img = embedding_post_process(embedding, bin_seg_pred, 0.5)
File "C:\Users\sarim\PycharmProjects\thesis\lane_files\utils\postprocess.py", line 29, in embedding_post_process
mean_shift.fit(embedding_reshaped)
File "C:\Users\sarim\AppData\Local\Programs\Python\Python37\lib\site-packages\sklearn\cluster\mean_shift_.py", line 424, in fit
cluster_all=self.cluster_all, n_jobs=self.n_jobs)
File "C:\Users\sarim\AppData\Local\Programs\Python\Python37\lib\site-packages\sklearn\cluster\mean_shift_.py", line 204, in mean_shift
(seed, X, nbrs, max_iter) for seed in seeds)
File "C:\Users\sarim\AppData\Local\Programs\Python\Python37\lib\site-packages\joblib\parallel.py", line 934, in __call__
self.retrieve()
File "C:\Users\sarim\AppData\Local\Programs\Python\Python37\lib\site-packages\joblib\parallel.py", line 833, in retrieve
self._output.extend(job.get(timeout=self.timeout))
File "C:\Users\sarim\AppData\Local\Programs\Python\Python37\lib\site-packages\joblib\_parallel_backends.py", line 521, in wrap_future_result
return future.result(timeout=timeout)
File "C:\Users\sarim\AppData\Local\Programs\Python\Python37\lib\concurrent\futures\_base.py", line 435, in result
return self.__get_result()
File "C:\Users\sarim\AppData\Local\Programs\Python\Python37\lib\concurrent\futures\_base.py", line 384, in __get_result
raise self._exception
File "C:\Users\sarim\AppData\Local\Programs\Python\Python37\lib\site-packages\joblib\externals\loky\_base.py", line 625, in _invoke_callbacks
callback(self)
File "C:\Users\sarim\AppData\Local\Programs\Python\Python37\lib\site-packages\joblib\parallel.py", line 309, in __call__
self.parallel.dispatch_next()
File "C:\Users\sarim\AppData\Local\Programs\Python\Python37\lib\site-packages\joblib\parallel.py", line 731, in dispatch_next
if not self.dispatch_one_batch(self._original_iterator):
File "C:\Users\sarim\AppData\Local\Programs\Python\Python37\lib\site-packages\joblib\parallel.py", line 759, in dispatch_one_batch
self._dispatch(tasks)
File "C:\Users\sarim\AppData\Local\Programs\Python\Python37\lib\site-packages\joblib\parallel.py", line 716, in _dispatch
job = self._backend.apply_async(batch, callback=cb)
File "C:\Users\sarim\AppData\Local\Programs\Python\Python37\lib\site-packages\joblib\_parallel_backends.py", line 510, in apply_async
future = self._workers.submit(SafeFunction(func))
File "C:\Users\sarim\AppData\Local\Programs\Python\Python37\lib\site-packages\joblib\externals\loky\reusable_executor.py", line 151, in submit
fn, *args, **kwargs)
File "C:\Users\sarim\AppData\Local\Programs\Python\Python37\lib\site-packages\joblib\externals\loky\process_executor.py", line 1022, in submit
raise self._flags.broken
joblib.externals.loky.process_executor.BrokenProcessPool: A task has failed to un-serialize. Please ensure that the arguments of the function are all picklable.