我有一个细分项目。我有图像和标签,它们保留了分割的基本事实。图像很大,并且包含很多"空"区域。 我想从图像和标签中剪切补丁,以便补丁中的标签非零。
我编写了以下代码,但速度非常慢。任何改进都将受到高度赞赏。
import numpy as np
import matplotlib.pyplot as plt
让我们创建虚拟数据
img = np.random.rand(300,200,3)
img[240:250,120:200]=0
mask = np.zeros((300,200))
mask[220:260,120:300]=0.7
mask[250:270,140:170]=0.3
f, axarr = plt.subplots(1,2, figsize = (10, 5))
axarr[0].imshow(img)
axarr[1].imshow(mask)[![enter image description here][1]][1]
plt.show()
我效率低下的代码:
IM_SIZE = 60 # Patch size
x_min, y_min = 0,0
x_max = img.shape[0] - IM_SIZE
y_max = img.shape[1] - IM_SIZE
xd, yd, x, y = 0,0,0,0
if (mask.max() > 0):
xd, yd = np.where(mask>0)
x_min = xd.min()
y_min = yd.min()
x_max = min(xd.max()- IM_SIZE-1, img.shape[0] - IM_SIZE-1)
y_max = min(yd.max()- IM_SIZE-1, img.shape[1] - IM_SIZE-1)
if (y_min >= y_max):
y = y_max
if (y + IM_SIZE >= img.shape[1] ):
print('Error')
else:
y = np.random.randint(y_min,y_max)
if (x_min>=x_max):
x = x_max
if (x+IM_SIZE >= img.shape[0] ):
print('Error')
else:
x = np.random.randint(x_min,x_max )
print(x,y)
img = img[x:x+IM_SIZE, y:y+IM_SIZE,:]
mask = mask[x:x+IM_SIZE, y:y+IM_SIZE]
f, axarr = plt.subplots(1,2, figsize = (10, 5))
axarr[0].imshow(img)
axarr[1].imshow(mask)
plt.show()
答案 0 :(得分:1)
大部分时间都是由mask.max()(对于某些加速可以更改为np.max(mask))和np.where(mask> 0)使用的。
如果您需要每次在不同的面具上使用where功能,请查看numexpr。或者您可以使用joblib通过并行运行许多此类案例来存储给定掩码的x / y_min / max结果。
使用numba.jit重新排列函数会给我带来更好的结果:
@jit
def temp(mask):
xd, yd = np.where(mask>0)
x_min = np.min(xd)
y_min = np.min(yd)
x_max = min(np.max(xd)- IM_SIZE-1, img.shape[0] - IM_SIZE-1)
y_max = min(np.max(yd)- IM_SIZE-1, img.shape[1] - IM_SIZE-1)
return x_min,x_max,y_min,y_max
def solver_new(img):
IM_SIZE = 60 # Patch size
x_min, y_min = 0,0
x_max = img.shape[0] - IM_SIZE
y_max = img.shape[1] - IM_SIZE
xd, yd, x, y = 0,0,0,0
if (np.max(mask) > 0):
x_min,x_max,y_min,y_max = temp(mask)
if (y_min >= y_max):
y = y_max
if (y + IM_SIZE >= img.shape[1] ):
print('Error')
else:
y = np.random.randint(y_min,y_max)
if (x_min>=x_max):
x = x_max
if (x+IM_SIZE >= img.shape[0] ):
print('Error')
else:
x = np.random.randint(x_min,x_max )
return x,y
由于图像和补丁尺寸很小,因此结果不太有意义,因为缓存对时间有很大影响。我在问题中发布的实现大约需要200us,而在此处发布的实现大约需要90us。