我希望在以较大2D阵列的位置(x,y)为中心的WxW窗口内找到最接近的匹配NxN块。下面的代码工作正常,但我的需求非常慢,因为我需要多次运行此操作。有一个更好的方法吗?? 这里N = 3,W = 15,x = 15,y = 15和(bestx,besty)是最佳匹配块的中心
import numpy as np
## Generate some test data
CurPatch = np.random.randint(20, size=(3, 3))
Data = np.random.randint(20,size=(30,30))
# Current Location
x,y = 15,15
# Initialise Best Match
bestcost = 999.0
bestx = 0;besty=0
for Wy in xrange(-7,8):
for Wx in xrange(-7,8):
Ywj,Ywi = y+Wy,x+Wx
cost = 0.0
for py in xrange(3):
for px in xrange(3):
cost += abs(Data[Ywj+py-1,Ywi+px-1] - CurPatch[py,px])
if cost < bestcost:
bestcost = cost
besty,bestx = Ywj,Ywi
print besty,bestx
答案 0 :(得分:0)
正如我在评论中所说,您可以检查cost
中bestcost
是否大于或等于for px in xrange(3):
,如果是这样,您可以中断,这样可以节省大量不必要的迭代
示例(将更改灯以强调更大迭代的差异):
import numpy as np
import time
## Generate some test data
CurPatch = np.random.randint(100, size=(3, 3))
Data = np.random.randint(100, size=(3000,3000))
# Current Location
x,y = 10, 10
# Initialise Best Match
bestcost = 999.0
bestx = 0;besty=0
t0 = time.time()
for Wy in xrange(-7,50):
for Wx in xrange(-7,50):
Ywj, Ywi = y+Wy, x+Wx
cost = 0.0
for py in xrange(3):
for px in xrange(3):
cost += abs(Data[Ywj+py-1,Ywi+px-1] - CurPatch[py,px])
if cost >= bestcost:
break
if cost < bestcost:
bestcost = cost
besty,bestx = Ywj,Ywi
print besty, bestx
print "time: {}".format(time.time() - t0)
时间是26毫秒
时间:0.0269999504089
没有中断的代码将输出37毫秒:
时间:0.0379998683929
此外,我还建议将此代码转换为函数。
答案 1 :(得分:0)
要感受速度,其中w与大窗口大小相同的子问题使用numpy更快(更简洁):
a= '''import numpy as np
## Generate some test data
CurPatch = np.random.randint(20, size=(3, 3))
Data = np.random.randint(20,size=(30,30))
def best(CurPatch,Data):
# Current Location
x,y = 15,15
# Initialise Best Match
bestcost = 999.0
bestx = 0;besty=0
for Wy in xrange(-14,14):
for Wx in xrange(-14,14):
Ywj,Ywi = y+Wy,x+Wx
cost = 0.0
for py in xrange(3):
for px in xrange(3):
cost += (Data[Ywj+py-1,Ywi+px-1] - CurPatch[py,px])**2
if cost < bestcost:
bestcost = cost
besty,bestx = Ywj,Ywi
return besty,bestx,bestcost
def minimize(CurPatch,W):
max_sum=999
s= CurPatch.shape[0]
S= W.shape[0]
for i in range(0,S-s):
for j in range(0,S-s):
running= np.sum(np.square((W[i:i+3,j:j+3]-CurPatch)))
if running<max_sum:
max_sum=running
x=i+1;y=j+1
return x,y,max_sum
'''
import timeit
print min(timeit.Timer('minimize(CurPatch,Data)', a).repeat(7, 10))
print min(timeit.Timer('best(CurPatch,Data)', a).repeat(7, 10))