为了使用 PyStruct 执行图像分割(通过推理 [1]),我首先需要构建一个图:其节点对应像素,边则是这些像素之间的连接。
因此我编写了一个函数来实现:
def create_graph_for_pystruct(mrf, betas, nb_labels):
    """Build the edge list and pairwise potentials of a pixel graph.

    Every pixel that is not on the image border is linked to the pixels
    returned by ``getNeighborhood`` and contributes eight
    ``nb_labels x nb_labels`` diagonal matrices to the pairwise stack, in
    the order ``b1, b1, b2, b2, b3, b3, b4, b4``.

    Args:
        mrf: 2-D array; only its shape ``(M, N)`` is used.
        betas: Sequence of exactly four weights ``(b1, b2, b3, b4)``.
        nb_labels: Number of labels, i.e. the side of each pairwise matrix.

    Returns:
        A tuple ``(edges, pairwise)``. ``edges`` is an array whose rows are
        ``(pixel, neighbor)`` linear indices (``row * N + col``).
        ``pairwise`` is a masked array of shape
        ``(nb_labels, nb_labels, 8 * n_interior)``. When the image has no
        interior pixel, ``edges`` has shape ``(0, 2)`` and ``pairwise`` has
        a zero-length last axis.
    """
    M, N = mrf.shape
    b1, b2, b3, b4 = betas

    # Border pixels are skipped, so only rows 1..M-2 and columns 1..N-2
    # take part; the count is clamped at zero for images smaller than 3x3.
    n_interior = max(M - 2, 0) * max(N - 2, 0)
    if n_interior == 0:
        empty_edges = np.empty((0, 2), dtype=int)
        empty_pairwise = np.ma.zeros((nb_labels, nb_labels, 0))
        return empty_edges, empty_pairwise

    edges = []
    for i in range(1, M - 1):
        for j in range(1, N - 1):
            current_linear_ind = i * N + j
            # NOTE(review): getNeighborhood is defined elsewhere; it is
            # assumed to return a list of (row, col) tuples -- confirm.
            neigh = np.array(getNeighborhood(i, j, M, N))
            # Convert the neighbors' (row, col) pairs to linear indices.
            neigh_linear_ind = neigh[:, 0] * N + neigh[:, 1]
            edges.extend((current_linear_ind, n) for n in neigh_linear_ind)

    # The eight matrices are the same for every interior pixel, so build
    # them once and stack a single time instead of re-stacking the whole
    # array at every pixel (which made the original quadratic).
    identity = np.eye(nb_labels)
    per_pixel = tuple(
        b * identity for b in (b1, b1, b2, b2, b3, b3, b4, b4)
    )
    pairwise = np.ma.dstack(per_pixel * n_interior)
    return np.array(edges), pairwise
然而,它运行得很慢,我想知道可以在哪些地方改进这个函数以加快速度。 [1] https://pystruct.github.io/generated/pystruct.inference.inference_dispatch.html
答案 0 :(得分:2)
下面是一个应该运行得更快的实现(在 numpy 中应尽量用向量化操作代替 for 循环)。我尝试用向量化一次性构建整个输出,并使用很方便的 np.ogrid 来生成 xy 坐标。
def new(mrf, betas, nb_labels):
    """Vectorized construction of the pixel-graph edges and pairwise stack.

    Each pixel that is not on the image border is linked to its eight
    surrounding pixels and contributes eight ``nb_labels x nb_labels``
    diagonal matrices, in the order ``b1, b1, b2, b2, b3, b3, b4, b4``.

    Args:
        mrf: 2-D array; only its shape ``(M, N)`` is used.
        betas: Sequence of exactly four weights ``(b1, b2, b3, b4)``.
        nb_labels: Number of labels, i.e. the side of each pairwise matrix.

    Returns:
        A tuple ``(edges, pairwise)``. ``edges`` is an integer array of
        shape ``(8 * n_interior, 2)`` whose rows are ``(center, neighbor)``
        linear indices (``row * N + col``), grouped by neighbor position.
        ``pairwise`` has shape ``(nb_labels, nb_labels, 8 * n_interior)``.
        Both are empty along the edge axis when the image has no interior
        pixel.
    """
    M, N = mrf.shape
    b1, b2, b3, b4 = betas

    # Clamp at zero: a plain (M-2)*(N-2) is positive for a 1x1 image.
    n_interior = max(M - 2, 0) * max(N - 2, 0)
    if n_interior == 0:
        return (
            np.empty((0, 2), dtype=int),
            np.empty((nb_labels, nb_labels, 0)),
        )

    # One (L, L, 8) block per interior pixel, repeated along the last axis.
    weights = np.repeat(np.array([b1, b2, b3, b4]), 2)
    per_pixel = np.eye(nb_labels)[:, :, None] * weights
    pairwise = np.tile(per_pixel, (1, 1, n_interior))

    # Linear index of every pixel, built by broadcasting the open grid.
    rows, cols = np.ogrid[0:M, 0:N]
    linear = rows * N + cols

    center = linear[1:-1, 1:-1]
    # (row, col) offsets of the eight neighbors, column offset varying
    # slowest, which fixes the order of the returned edges.
    offsets = (
        (-1, -1), (0, -1), (1, -1),
        (-1, 0), (1, 0),
        (-1, 1), (0, 1), (1, 1),
    )
    neighbors = np.stack([
        linear[1 + dr:M - 1 + dr, 1 + dc:N - 1 + dc]
        for dr, dc in offsets
    ])

    # Keep the integer dtype: edges are indices, not floats.
    edges = np.empty(neighbors.shape + (2,), dtype=neighbors.dtype)
    edges[..., 0] = center
    edges[..., 1] = neighbors
    return edges.reshape(-1, 2), pairwise
时间和比较测试:
import timeit
# Benchmark input: a 40x50 image, four beta weights and ten labels.
args = (np.empty((40, 50)), [1, 2, 3, 4], 10)


def f1():
    """Run the vectorized implementation on the benchmark input."""
    return new(*args)


def f2():
    """Run the loop-based implementation on the benchmark input."""
    return create_graph_for_pystruct(*args)


edges1, pairwise1 = f1()
edges2, pairwise2 = f2()

# The two implementations emit the same edges in a different order, so
# both edge lists are sorted row-wise (first column, then second) before
# being compared.
row_order1 = np.lexsort(np.fliplr(edges1).T)
edges1 = edges1[row_order1]
row_order2 = np.lexsort(np.fliplr(edges2).T)
edges2 = edges2[row_order2]

same_edges = (edges1 == edges2).all()
print("edges identical ?", same_edges)
same_pairwise = (pairwise1 == pairwise2).all()
print("pairwise identical ?", same_pairwise)

# Single-run timings of each implementation.
print("new : ", timeit.timeit(f1, number=1))
print("old : ", timeit.timeit(f2, number=1))
输出
edges identical ? True
pairwise identical ? True
new : 0.015270026000507642
old : 4.611805051001284
注意:我只能自行猜测 getNeighborhood 函数的实现。