我想计算两组中所有元素组合之间的距离。
descriptor_1
(resp。descriptor_2
)是长度为N1(分别为N2)的2D数组列表(每个元素一个2D数组)。
要计算这两组之间的所有组合,我使用:
combi = list(itertools.product(descriptor_1, descriptor_2))
生成一个长度为N1*N2
的2-uples列表。
计算距离:
dist = map(chi2_dist, combi)
其中:
def chi2_dist(a, b):
a = a.flatten()
b = b.flatten()
dist = (1/2) * np.sum( (a-b)**2 / (a+b+EPS))
return dist
但是我收到以下错误:
TypeError: chi2_dist() takes exactly 2 arguments (1 given)
但是,由于我的元组包含2个元素,我不明白错误。
答案 0 :(得分:2)
你应该
def chi2_dist(ab):
a = ab[0]
b = ab[1]
a = a.flatten()
b = b.flatten()
顺便说一句,BTW,更有效率
map(chi2_dist, itertools.product(descriptor_1, descriptor_2))
不需要中间列表