我编写了一个函数,它从非均匀分布中提取元素并返回输入数组元素的索引,就好像它们是从均匀分布中拉出来的一样。这是代码和示例:
import numpy as np
def uniform_choice(x, n):
unique, counts = np.unique(x, return_counts=True)
element_freq = np.zeros(x.shape)
for i in range(len(unique)):
element_freq[np.where(x == unique[i])[0]] = counts[i]
p = 1/element_freq/(1/element_freq).sum()
return np.random.choice(x, n, False, p)
x = np.random.choice(["a", "b", "c", "d"], 100000, p=(0.1, 0.2, 0.3, 0.4))
#so this gives an non-uniform distribution of elements "a", "b", "c", "d"
np.unique(x, return_counts=True)
#returns
(array(['a', 'b', 'c', 'd'], dtype='<U1'),
array([10082, 19888, 30231, 39799]))
使用我的函数,我可以从这个分布中提取元素,并获得索引,好像它们是从均匀分布中拉出的一样:
np.unique(uniform_choice(x, 5000), return_counts=True)
#returns
array([23389, 90961, 78455, ..., 41405, 22894, 79686])
是否可以避免我的函数中的for循环。我需要在非常大的数组上进行很多次采样,因此这变得很慢。我相信比较的矢量化版本会给我更快的结果。
答案 0 :(得分:2)
你可以杀死循环部分,我假设这是最耗时的部分,通过扩展使用np.unique
来合并return_inverse=True
,这会给我们带来唯一的数字x
中每个唯一字符串的标签。然后可以将这些数字标签用作索引,引导我们进行element_freq
的矢量化计算。因此,loopy部分 -
unique, counts = np.unique(x, return_counts=True)
element_freq = np.zeros(x.shape)
for i in range(len(unique)):
element_freq[np.where(x == unique[i])[0]] = counts[i]
将替换为 -
unique, idx, counts = np.unique(x, return_inverse=True, return_counts=True)
element_freq = counts[idx]
运行时测试 -
In [18]: x = np.random.choice(["a", "b", "c", "d"], 100000, p=(0.1, 0.2, 0.3, 0.4))
In [19]: %%timeit
...: unique, counts = np.unique(x, return_counts=True)
...: element_freq = np.zeros(x.shape)
...: for i in range(len(unique)):
...: element_freq[np.where(x == unique[i])[0]] = counts[i]
...:
100 loops, best of 3: 18.9 ms per loop
In [20]: %%timeit
...: unique, idx, counts =np.unique(x,return_inverse=True, return_counts=True)
...: element_freq = counts[idx]
...:
100 loops, best of 3: 12.9 ms per loop
答案 1 :(得分:0)
也许是这样的(未经测试):
def uniform_choice(x, n):
unique = np.unique(x)
values = np.random.choice(x, n, False)
return np.searchsorted(x, values, sorter=np.argsort(x))
这将从唯一集生成n
个值,然后使用searchsorted
在原始数组中查找这些值,并返回其索引。
我期望从这种方法中得到的一个区别是,您只会在x
中获得每个值出现的第一个索引。也就是说,x
多次出现的值将始终由其出现的单个索引表示,而在原始代码中,它可能是多个。