I have the following function, which returns an array of computed nearest neighbors:
def p_batch(U,X,Y):
    return [nearest(u,X,Y) for u in U]
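I would like to replace the for loop with NumPy. I have been looking at numpy.vectorize(), since that seems to be the right approach, but I can't get it to work. This is what I have tried so far:
def n_batch(U,X,Y):
    vbatch = np.vectorize(nearest)
    return vbatch(U,X,Y)
Can anyone give me a hint as to where I am going wrong?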
Edit
Implementation of nearest:
def nearest(u,X,Y):
    return Y[np.argmin(np.sqrt(np.sum(np.square(np.subtract(u,X)),axis=1)))]
Definitions of U, X and Y (with M = 20, N = 100, d = 50):
U = numpy.random.mtrand.RandomState(123).uniform(0,1,[M,d])
X = numpy.random.mtrand.RandomState(456).uniform(0,1,[N,d])
Y = numpy.random.mtrand.RandomState(789).randint(0,2,[N])
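Answer 0 (score: 2)
Approach #1
You could use SciPy's cdist to generate all of those Euclidean distances, then simply take the argmin along each row and index into Y with it: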
from scipy.spatial.distance import cdist
out = Y[cdist(U,X).argmin(1)]
Sample run:
In [76]: M,N,d = 5,6,3
...: U = np.random.mtrand.RandomState(123).uniform(0,1,[M,d])
...: X = np.random.mtrand.RandomState(456).uniform(0,1,[N,d])
...: Y = np.random.mtrand.RandomState(789).randint(0,2,[N])
...:
# Using a loop comprehension to verify values
In [77]: [nearest(U[i], X,Y) for i in range(len(U))]
Out[77]: [1, 0, 0, 1, 1]
In [78]: Y[cdist(U,X).argmin(1)]
Out[78]: array([1, 0, 0, 1, 1])
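For readers without SciPy, here is a minimal NumPy-only sketch of the same idea, using broadcasting to build the full distance matrix. The helper name n_batch_broadcast is hypothetical and not part of the original answer. It skips the square root, which should not change the argmin because the square root is monotonic. Note that it materializes an (M, N, d) intermediate array, so it may use much more memory than cdist for large inputs.

```python
import numpy as np

def nearest(u, X, Y):
    # Per-query implementation from the question
    return Y[np.argmin(np.sqrt(np.sum(np.square(np.subtract(u, X)), axis=1)))]

def n_batch_broadcast(U, X, Y):
    # Hypothetical helper name, for illustration only.
    # diff[m, n, :] = U[m] - X[n], shape (M, N, d)
    diff = U[:, None, :] - X[None, :, :]
    # Squared Euclidean distances, shape (M, N)
    sq_dist = (diff ** 2).sum(axis=2)
    # Index of the closest row of X for each query row of U, then look up its label
    return Y[sq_dist.argmin(axis=1)]

M, N, d = 5, 6, 3
U = np.random.RandomState(123).uniform(0, 1, [M, d])
X = np.random.RandomState(456).uniform(0, 1, [N, d])
Y = np.random.RandomState(789).randint(0, 2, [N])

out = n_batch_broadcast(U, X, Y)
print(out)
```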
Approach #2
Another way is sklearn's pairwise_distances_argmin, which gives us the argmin indices directly:
from sklearn.metrics import pairwise
Y[pairwise.pairwise_distances_argmin(U,X)]
Runtime test with M=20, N=100, d=50:
In [90]: M,N,d = 20,100,50
...: U = np.random.mtrand.RandomState(123).uniform(0,1,[M,d])
...: X = np.random.mtrand.RandomState(456).uniform(0,1,[N,d])
...: Y = np.random.mtrand.RandomState(789).randint(0,2,[N])
...:
Timings for cdist and pairwise_distances_argmin:
In [91]: %timeit cdist(U,X).argmin(1)
10000 loops, best of 3: 55.2 µs per loop
In [92]: %timeit pairwise.pairwise_distances_argmin(U,X)
10000 loops, best of 3: 90.6 µs per loop
Timings against the loopy version:
In [93]: %timeit [nearest(U[i], X,Y) for i in range(len(U))]
1000 loops, best of 3: 298 µs per loop
In [94]: %timeit Y[cdist(U,X).argmin(1)]
10000 loops, best of 3: 55.6 µs per loop
In [95]: %timeit Y[pairwise.pairwise_distances_argmin(U,X)]
10000 loops, best of 3: 91.1 µs per loop
In [96]: 298.0/55.6 # Speedup with cdist over loopy one
Out[96]: 5.359712230215827
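On the np.vectorize attempt in the question: without a signature, np.vectorize broadcasts all arguments against each other and calls the function on scalars, so arrays shaped (M, d), (N, d) and (N,) do not fit that model. A possible fix, sketched below, is the signature argument (available in NumPy 1.12 and later), which declares the core dimensions each call should receive. The signature string here is an assumption based on the shapes given in the question. np.vectorize is still essentially a Python-level loop, so this is unlikely to be faster than the list comprehension; cdist remains the better choice for speed.

```python
import numpy as np

def nearest(u, X, Y):
    # Per-query implementation from the question
    return Y[np.argmin(np.sqrt(np.sum(np.square(np.subtract(u, X)), axis=1)))]

# Each call gets one row of U (d,), the whole of X (n, d) and the whole of Y (n,),
# and returns a scalar label.
vnearest = np.vectorize(nearest, signature='(d),(n,d),(n)->()')

def n_batch(U, X, Y):
    return vnearest(U, X, Y)

M, N, d = 20, 100, 50
U = np.random.RandomState(123).uniform(0, 1, [M, d])
X = np.random.RandomState(456).uniform(0, 1, [N, d])
Y = np.random.RandomState(789).randint(0, 2, [N])

out_vec = n_batch(U, X, Y)
print(out_vec.shape)
```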