如何按给定比例将一台发电机随机分成两台?

时间:2019-02-05 18:10:22

标签: python generator

假设有一个生成器,该生成器可生成一定数量的数据。 有没有办法获得两个生成相同数据的生成器,其中两个生成器生成的项的数量是按一定比例给出的,例如发生器1产生80%的数据,而发生器2产生20%的数据。这应该是随机发生的。

当然会生成第一个生成器的列表,将列表洗牌并将其分为两部分是直接的解决方案。但是,我想知道是否有可能在不必将整个内容存储到内存中的情况下实现这一目标。

最好, 马特

我想到的是

def split_generator(data_generator, percentage_gen_1):
    original_generator, generator_copy = tee(data_generator)
    n_entries = sum(1 for item in generator_copy)

    split_idx = int(n_entries * percentage_gen_1)

    gen_1 = islice(original_generator, 0, split_idx)

    # I found that the remaining part of original_generator works
    # as the remaining (1 - percentage_gen_1) part

    return gen_1, original_generator

这有两个缺点。它不是随机的,我认为tee将整个内容存储在内存中,因此没有理由过度转换为列表。

2 个答案:

答案 0 :(得分:1)

此解决方案不存储值。它设置两个相同的生成器和两个相同的随机数流。生成器共享相同的截止百分比,其中一个仅在其以下生成收益,而另一种仅在其之上生成收益:

from random import Random

def percentage_generators(generator, percentage):

    def generator_1(state):
        twister = Random()
        twister.setstate(state)

        for value in generator():
            if twister.random() < percentage:
                yield value

    def generator_2(state):
        twister = Random()
        twister.setstate(state)

        for value in generator():
            if twister.random() >= percentage:
                yield value

    state = Random().getstate()

    return [generator_1(state), generator_2(state)]

if __name__ == "__main__":

    def test_generator():
        for n in range(20):
            yield n

    generator1, generator2 = percentage_generators(test_generator, 0.7)

    for number in generator1:
        print(1, number)

    print()

    for number in generator2:
        print(2, number)

输出

% python3 test.py
1 0
1 1
1 2
1 3
1 6
1 7
1 8
1 10
1 11
1 12
1 13
1 14
1 15
1 17

2 4
2 5
2 9
2 16
2 18
2 19
%

通过循环生成生成器包装器(即在operator.ltoperator.ge或类似的循环中生成代码,可以减少代码。

答案 1 :(得分:0)

这是一种将一些生成器的值存储在内存中的方法,但不是全部。特别是,它仅将值存储在任一生成器最后生成的值之间。例如,当生成正整数时,如果a最后产生23,而b最后产生42,则仅24到41存储在存储器中。

from collections import deque
import random

def randsplit(g):
    g = iter(g)
    queues = [deque(), deque()]

    def fill_queues():
        x = next(g)
        if random.random() < 0.8:
            queues[0].append(x)
        else:
            queues[1].append(x)

    def iter_from_queue(q):
        while True:
            while not q:
                try:
                    fill_queues()
                except StopIteration:
                    return
            yield q.popleft()

    return [iter_from_queue(queues[0]), iter_from_queue(queues[1])]

a,b = randsplit(range(20))

print("iterating through a.")
for item in a: print(item)

print("iterating through b.")
for item in b: print(item)

一个可能的结果:

iterating through a.
0
3
4
5
6
7
8
9
11
12
13
14
15
16
17
18
19
iterating through b.
1
2
10