np.shuffle比np.random.choice

时间:2018-09-11 02:16:10

标签: python numpy

我有一个形状为(N,3)的数组,我想随机地对行进行随机排序。 N大约是100,000。

我发现np.random.shuffle阻碍了我的应用程序。我尝试通过调用np.random.choice代替随机播放,并经历了10倍的加速。这里发生了什么?为什么调用np.random.choice这么快? np.random.choice版本是否会生成均匀分布的随机播放?

import timeit

task_choice = '''
N = 100000
x = np.zeros((N, 3))
inds = np.random.choice(N, N, replace=False)
x[np.arange(N), :] = x[inds, :]
'''

task_shuffle = '''
N = 100000
x = np.zeros((N, 3))
np.random.shuffle(x)
'''

task_permute = '''
N = 100000
x = np.zeros((N, 3))
x = np.random.permutation(x)
'''

setup = 'import numpy as np'

timeit.timeit(task_choice, setup=setup, number=10)
>>> 0.11108078400138766

timeit.timeit(task_shuffle, setup=setup, number=10)
>>> 1.0411593900062144

timeit.timeit(task_permute, setup=setup, number=10)
>>> 1.1140159380011028

编辑:对于有好奇心的人,我决定采用以下解决方案,因为该解决方案可读性强并且优于基准测试中的所有其他方法:

task_ind_permute = '''
N = 100000
x = np.zeros((N, 3))
inds = np.random.permutation(N)
x[np.arange(N), :] = x[inds, :]
'''

2 个答案:

答案 0 :(得分:1)

您在这里比较非常个不同大小的数组。在第一个示例中,尽管创建了一个零数组,但您仅使用random.choice(100000, 100000),它会在1-100000之间抽取100000个随机值。在第二个示例中,您正在改组(100000, 3)形状数组。

>>> x.shape
(100000, 3)
>>> np.random.choice(N, N, replace=False).shape
(100000,)

对更多等效样本的计时:

In [979]: %timeit np.random.choice(N, N, replace=False)
2.6 ms ± 201 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [980]: x = np.arange(100000)

In [981]: %timeit np.random.shuffle(x)
2.29 ms ± 67.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [982]: x.shape == np.random.choice(N, N, replace=False).shape
Out[982]: True

答案 1 :(得分:1)

permutationshuffle是链接的,实际上permutation在幕后叫shuffle

对于多维数组,shufflepermutation慢的原因是,permutation仅需要shuffle沿第一轴的索引。因此成为1d数组{{if-1else的第一个块)的shuffle的特例。

此特殊情况在源代码中也有说明:

# We trick gcc into providing a specialized implementation for
# the most common case, yielding a ~33% performance improvement.
# Note that apparently, only one branch can ever be specialized.

另一方面,对于shuffle,多维ndarray操作需要一个反弹缓冲区,特别是在维数较大时,创建该缓冲区会变得昂贵。此外,我们将无法再使用上述有助于1d案例的技巧。

使用replace=False并使用choice生成相同大小的新数组,choicepermutation相同,请参见here。额外的时间必须来自创建中间索引数组所花费的时间。