我对Java和Python中特定代码片段的性能有疑问。
算法:
我正在生成随机N维点,然后对于彼此在一定距离阈值之下的所有点,我做了一些处理。处理本身并不重要,因为它不会影响总执行时间。在两种情况下生成点也需要几分之一秒,所以我只对进行比较的部分感兴趣。
执行时间:
对于3000点和2维的固定输入,Java在 2到4秒中执行此操作,而Python需要在15到200秒之间。
我对Python的执行时间有点怀疑。这个Python代码中有什么我想念的吗?是否有任何算法改进建议(例如预分配/重用内存,降低Big-Oh复杂性的方法等)?
爪哇
double random_points[][] = new double[number_of_points][dimensions];
for(i = 0; i < number_of_points; i++)
for(d = 0; d < dimensions; d++)
random_points[i][d] = Math.random();
double p1[], p2[];
for(i = 0; i < number_of_points; i++)
{
p1 = random_points[i];
for(j = i + 1; j < number_of_points; j++)
{
p2 = random_points[j];
double sum_of_squares = 0;
for(d = 0; d < DIM_; d++)
sum_of_squares += (p2[d] - p1[d]) * (p2[d] - p1[d]);
double distance = Math.sqrt(ss);
if(distance > SOME_THRESHOLD) continue;
//...else do something with p1 and p2
}
}
Python 3.2
random_points = [[random.random() for _d in range(0,dimensions)] for _n in range(0,number_of_points)]
for i, p1 in enumerate(random_points):
for j, p2 in enumerate(random_points[i+1:]):
distance = math.sqrt(sum([(p1[d]-p2[d])**2 for d in range(0,dimensions)]))
if distance > SOME_THRESHOLD: continue
#...else do something with p1 and p2
答案 0 :(得分:5)
您可能需要考虑使用numpy。
我刚试过以下内容:
import numpy
from scipy.spatial.distance import pdist
D=2
N=3000
p=numpy.random.uniform(size=(N,D))
dist=pdist(p, 'euclidean')
最后一行计算距离矩阵(这相当于代码中为每对点计算distance
)。在我的电脑上大概需要0.07秒。
这种方法的主要缺点是距离矩阵需要O(n^2)
内存。如果这是一个问题,以下可能是更好的选择:
for i in xrange(1, N):
v = p[:N-i] - p[i:]
dist = numpy.sqrt(numpy.sum(numpy.square(v), axis=1))
for j in numpy.nonzero(dist > 1.4)[0]:
print j, i+j
对于N = 3000,我的电脑需要约0.33秒。
答案 1 :(得分:2)
如果我采取30K点和5维,这是多100倍的工作。
int number_of_points = 30000;
int dimensions = 5;
double SOME_THRESHOLD = 0.1;
long start = System.nanoTime();
double random_points[][] = new double[number_of_points][dimensions];
for (int i = 0; i < number_of_points; i++)
for (int d = 0; d < dimensions; d++)
random_points[i][d] = Math.random();
double p1[], p2[];
Comparator<double[]> compareX = new Comparator<double[]>() {
@Override
public int compare(double[] o1, double[] o2) {
return Double.compare(o1[0], o2[0]);
}
};
Arrays.sort(random_points, compareX);
double[] key = new double[dimensions];
int count = 0;
for (int i = 0; i < number_of_points; i++) {
p1 = random_points[i];
key[0] = p1[0] + SOME_THRESHOLD;
int index = Arrays.binarySearch(random_points, key, compareX);
if (index < 0) index = ~index;
NEXT: for (int j = i + 1; j < index; j++) {
p2 = random_points[j];
double sum_of_squares = 0;
for (int d = 0; d < dimensions; d++) {
sum_of_squares += (p2[d] - p1[d]) * (p2[d] - p1[d]);
if (sum_of_squares > SOME_THRESHOLD * SOME_THRESHOLD)
continue NEXT;
}
//...else do something with p1 and p2
count++;
}
}
long time = System.nanoTime() - start;
System.out.println("Took " + time / 1000 / 1000 + " ms. count= " + count);
打印
Took 1549 ms. count= 20197
答案 2 :(得分:1)
速度真的重要吗?这里有几个显而易见的加速:
不计算平方根。只需将您的阈值平方并与平方阈值进行比较。
沿一个维度(循环中的外部维度)对点进行排序。当两个点i
和j
比此维度的阈值更远时,进一步递增j
只会产生比此阈值更远的点,您可以continue
外环。
可能还有其他算法加速,即使上面仍然是O(n d )。