How can I optimize this code for NN prediction?

Asked: 2016-09-21 20:00:38

Tags: python performance numpy scipy nearest-neighbor

How can I optimize this code? It currently runs slowly because of this loop over the data. The code performs a 1-nearest-neighbour lookup: it predicts the label of training_element from p_data_set.
import numpy as np
from scipy.spatial import distance

#               [x] ,           [[x1],[x2],[x3]],    [l1, l2, l3]
def prediction(training_element, p_data_set, p_label_set):
    temp = np.array([], dtype=float)
    for p in p_data_set:
        # np.append reallocates and copies the whole array on every iteration
        temp = np.append(temp, distance.euclidean(training_element, p))

    minIndex = np.argmin(temp)
    return p_label_set[minIndex]

3 Answers:

Answer 0 (score: 2)

Use a k-d tree for fast nearest-neighbour lookups, e.g. scipy.spatial.cKDTree:

from scipy.spatial import cKDTree

# I assume that p_data_set is (nsamples, ndims)
tree = cKDTree(p_data_set)

# training_elements is also assumed to be (nsamples, ndims)
dist, idx = tree.query(training_elements, k=1)

predicted_labels = p_label_set[idx]
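A minimal self-contained sketch of this approach (the random data, labels, and sizes here are made up for illustration; the tree is built once and then queried for a whole batch of points):

```python
import numpy as np
from scipy.spatial import cKDTree

rng = np.random.default_rng(0)
p_data_set = rng.random((1000, 3))       # (nsamples, ndims)
p_label_set = np.arange(1000)            # one label per sample
training_elements = rng.random((5, 3))   # batch of query points

tree = cKDTree(p_data_set)               # build once, query many times
dist, idx = tree.query(training_elements, k=1)
predicted_labels = p_label_set[idx]

# brute-force check on the first query point: the tree's nearest
# neighbour must match the argmin over all pairwise distances
brute = np.linalg.norm(p_data_set - training_elements[0], axis=1).argmin()
assert idx[0] == brute
```

The win comes from amortizing the tree construction: each query is then sub-linear in the number of samples, instead of the full scan the original loop does per element.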

Answer 1 (score: 1)

You can use distance.cdist to get the distances temp directly, then use .argmin() to get the min-index, like so -

minIndex = distance.cdist(training_element[None],p_data_set).argmin()

Here's an alternative approach using np.einsum -
subs = p_data_set - training_element
minIndex = np.einsum('ij,ij->i',subs,subs).argmin()
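A quick sanity check (with made-up random data) that the two routes agree: the einsum expression yields *squared* distances, but since squaring is monotonic the argmin is the same as with cdist's true Euclidean distances.

```python
import numpy as np
from scipy.spatial import distance

rng = np.random.default_rng(42)
p_data_set = rng.random((500, 10))
training_element = rng.random(10)

# cdist expects 2D inputs, hence [None] to add a leading axis
idx_cdist = distance.cdist(training_element[None], p_data_set).argmin()

# einsum computes row-wise sums of squares (squared distances);
# skipping the sqrt does not change which index is smallest
subs = p_data_set - training_element
idx_einsum = np.einsum('ij,ij->i', subs, subs).argmin()

assert idx_cdist == idx_einsum
```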

Runtime test

Well, I was thinking cKDTree would easily beat cdist, but I guess training_element being a 1D array isn't too heavy for cdist, and I am seeing it beat cKDTree instead, by a good 10x+ margin!

Here are the timing results -

In [422]: # Setup arrays
     ...: p_data_set = np.random.randint(0,9,(40000,100))
     ...: training_element = np.random.randint(0,9,(100,))
     ...: 

In [423]: def tree_based(p_data_set,training_element): #@ali_m's soln
     ...:     tree = cKDTree(p_data_set)
     ...:     dist, idx = tree.query(training_element, k=1)
     ...:     return idx
     ...: 
     ...: def einsum_based(p_data_set,training_element):    
     ...:     subs = p_data_set - training_element
     ...:     return np.einsum('ij,ij->i',subs,subs).argmin()
     ...: 

In [424]: %timeit tree_based(p_data_set,training_element)
1 loops, best of 3: 210 ms per loop

In [425]: %timeit einsum_based(p_data_set,training_element)
100 loops, best of 3: 17.3 ms per loop

In [426]: %timeit distance.cdist(training_element[None],p_data_set).argmin()
100 loops, best of 3: 14.8 ms per loop

Answer 2 (score: 0)

Python can be a very fast programming language if used properly. Here is my suggestion (faster_prediction):

import numpy as np
import time

def euclidean(a,b):
    return np.linalg.norm(a-b)

def prediction(training_element, p_data_set, p_label_set):
    temp = np.array([], dtype=float)
    for p in p_data_set:
        temp = np.append(temp, euclidean(training_element, p))

    minIndex = np.argmin(temp)
    return p_label_set[minIndex]

def faster_prediction(training_element, p_data_set, p_label_set):    
    temp = np.tile(training_element, (p_data_set.shape[0],1))
    temp = np.sqrt(np.sum( (temp - p_data_set)**2 , 1))    

    minIndex = np.argmin(temp)
    return p_label_set[minIndex]   


training_element = [1,2,3]
p_data_set = np.random.rand(100000, 3)*10
p_label_set = np.r_[0:p_data_set.shape[0]]


t1 = time.time()
result_1 = prediction(training_element, p_data_set, p_label_set)
t2 = time.time()

t3 = time.time()
result_2 = faster_prediction(training_element, p_data_set, p_label_set)
t4 = time.time()


print("Execution time 1:", t2-t1, "value: ", result_1)
print("Execution time 2:", t4-t3, "value: ", result_2)
print("Speed up: ", (t2-t1) / (t4-t3))

I get the following results on a fairly old laptop:

Execution time 1: 21.6033108234 value:  9819
Execution time 2: 0.0176379680634 value:  9819
Speed up:  1224.81857013

Which makes me think I must have made some silly mistake :)

In the case of very huge data, where memory may become an issue, I suggest using Cython or implementing the function in C++ and wrapping it in Python.
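Before reaching for Cython or C++, the memory pressure of the fully vectorized version can also be bounded by processing p_data_set in fixed-size chunks. A sketch (the `chunk` parameter and helper name are hypothetical, not from the original post):

```python
import numpy as np

def chunked_prediction(training_element, p_data_set, p_label_set, chunk=10000):
    """1-NN label lookup that scans p_data_set in chunks to bound memory."""
    best_d2 = np.inf   # best squared distance seen so far
    best_idx = -1
    for start in range(0, p_data_set.shape[0], chunk):
        block = p_data_set[start:start + chunk]
        # broadcasting subtracts training_element from every row of the block
        d2 = np.sum((block - training_element) ** 2, axis=1)
        i = d2.argmin()
        if d2[i] < best_d2:
            best_d2, best_idx = d2[i], start + i
    return p_label_set[best_idx]
```

Only one chunk-sized temporary array is ever live, so peak memory stays roughly constant regardless of how large p_data_set grows, at the cost of a small amount of Python-level loop overhead per chunk.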