我有一个collections.deque()元组,我想从中抽取随机样本。
在Python 2.7中,我可以使用batch = random.sample(my_deque, batch_size)
。
但是在Python 3.4中,这会引发TypeError: Population must be a sequence or set. For dicts, use list(d).
什么是最佳的解决方法或推荐的方法从Python 3中的双端队列中有效地进行采样?
答案 0 :(得分:8)
显而易见的方法 - 转换为列表。
batch = random.sample(list(my_deque), batch_size))
但您可以避免创建整个列表。
idx_batch = set(sample(range(len(my_deque)), batch_size))
batch = [val for i, val in enumerate(my_deque) if i in idx_batch]
P.S。 (编者)
实际上,random.sample
应该可以在Python> = 3.5中使用deques。因为该类已更新以匹配Sequence接口。
In [3]: deq = collections.deque(range(100))
In [4]: random.sample(deq, 10)
Out[4]: [12, 64, 84, 77, 99, 69, 1, 93, 82, 35]
请注意!正如Geoffrey Irving在下面的评论中正确指出的那样,你最好将队列转换成一个列表,因为队列被实现为链表,使每个索引访问O(n)的队列大小,因此随机抽样值将花费O(m * n)时间。
答案 1 :(得分:0)
sample()
上的 deque
在Python≥3.5中工作正常,而且非常快。
在Python 3.4中,您可以改用它,它的运行速度差不多:
sample_indices = sample(range(len(deq)), 50)
[deq[index] for index in sample_indices]
在使用Python 3.6.8的MacBook上,该解决方案比Eli Korvigo的解决方案快44倍以上。 :)
我使用了deque
,其中包含100万个项目,并抽样了50个项目:
from random import sample
from collections import deque
deq = deque(maxlen=1000000)
for i in range(1000000):
deq.append(i)
sample_indices = set(sample(range(len(deq)), 50))
%timeit [deq[i] for i in sample_indices]
1.68 ms ± 23.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit sample(deq, 50)
1.94 ms ± 60.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit sample(range(len(deq)), 50)
44.9 µs ± 549 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit [val for index, val in enumerate(deq) if index in sample_indices]
75.1 ms ± 410 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
正如其他人指出的那样,deque
不太适合随机访问。如果要实现重播内存,可以改用如下所示的循环列表:
class ReplayMemory:
def __init__(self, max_size):
self.buffer = [None] * max_size
self.max_size = max_size
self.index = 0
self.size = 0
def append(self, obj):
self.buffer[self.index] = obj
self.size = min(self.size + 1, self.max_size)
self.index = (self.index + 1) % self.max_size
def sample(self, batch_size):
indices = sample(range(self.size), batch_size)
return [self.buffer[index] for index in indices]
百万个项目中,有50个项目的采样速度非常快:
%timeit mem.sample(50)
#58 µs ± 691 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)