计算非空区域的快速方法

时间:2014-10-03 13:26:02

标签: python performance algorithm

我正在编写一些代码,这些代码在5个维度中选择n个随机超平面。然后,它在单位球上随机均匀地对no_points点进行采样,并计算由超平面创建的区域中至少有一个点。使用以下Python代码可以相对简单。

import numpy as np

def points_on_sphere(dim, N, norm=np.random.normal):
    """
    http://en.wikipedia.org/wiki/N-sphere#Generating_random_points
    """
    normal_deviates = norm(size=(N, dim))
    radius = np.sqrt((normal_deviates ** 2).sum(axis=0))
    points = normal_deviates / radius
    return points

n = 100
d = 5
hpoints = points_on_sphere(n, d).T
for no_points in xrange(0, 10000000,100000):
    test_points = points_on_sphere(no_points,d).T 
    #The next two lines count how many of the test_points are in different regions created by the hyperplanes
    signs = np.sign(np.inner(test_points, hpoints))
    print no_points, len(set(map(tuple,signs)))

不幸的是,我计算不同地区有多少点的天真方法很慢。总体而言,该方法需要O(no_points * n * d)时间,实际上,当no_points达到1000000时,它太慢而且RAM太忙。特别是它在no_points = 900,000达到4GB的RAM。

这可以更有效地完成,以便no_points能够相当快地并且使用少于4GB的RAM,一直可以达到10,000,000(实际上它可以达到10倍)吗?

1 个答案:

答案 0 :(得分:2)

存储每个测试点对每个超平面的分类方式是很多数据。我建议在点标签上使用隐式基数排序,例如,

import numpy as np


d = 5
n = 100
N = 100000
is_boundary = np.zeros(N, dtype=bool)
tpoints = np.random.normal(size=(N, d))
tperm = np.arange(N)
for i in range(n):
    hpoint = np.random.normal(size=d)
    region = np.cumsum(is_boundary) * 2 + (np.inner(hpoint, tpoints) < 0.0)[tperm]
    region_order = np.argsort(region)
    is_boundary[1:] = np.diff(region[region_order])
    tperm = tperm[region_order]
print(np.sum(is_boundary))

此代码保留测试点(tperm)的排列,使得同一区域中的所有点都是连续的。 boundary表示每个点是否与排列顺序中的前一个区域不同。对于每个连续的超平面,我们对每个现有区域进行分区,并有效地丢弃空区域,以避免存储2 ^ 100个区域。

实际上,由于你有很多分数和很少的超平面,所以不存储积分更有意义。以下代码使用二进制编码将区域签名打包为两个双精度。

import numpy as np


d = 5
hpoints = np.random.normal(size=(100, d))
bits = np.zeros((2, 100))
bits[0, :50] = 2.0 ** np.arange(50)
bits[1, 50:] = 2.0 ** np.arange(50)
N = 100000
uniques = set()
for i in xrange(0, N, 1000):
    tpoints = np.random.normal(size=(1000, d))
    signatures = np.inner(np.inner(tpoints, hpoints) < 0.0, bits)
    uniques.update(map(tuple, signatures))
print(len(uniques))