Cartesian product with generators

Time: 2021-06-21 11:08:52

Tags: python recursion generator cartesian

This is related to Cartesian product of nested dictionaries of lists

Suppose I have a nested dictionary containing lists that represent multiple configurations, for example:

{'algorithm': ['PPO', 'A2C', 'DQN'],
 'env_config': {'env': 'GymEnvWrapper-Atari',
                'env_config': {'AtariEnv': {'game': ['breakout', 'pong']}}}}

The goal is to compute the Cartesian product of the lists in the nested dictionary to obtain all possible configurations.
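For a flat dictionary whose values are all lists, this is the standard itertools idiom; the nested case generalizes it by recursing into dict values. A minimal sketch for the flat case only (flat_product and grid are hypothetical names chosen for illustration):

```python
from itertools import product


def flat_product(grid):
    """Yield one dict per combination of the list values in a flat dict (hypothetical helper)."""
    keys = list(grid)
    # product() varies the last key fastest, like nested for-loops
    for combo in product(*(grid[k] for k in keys)):
        yield dict(zip(keys, combo))


grid = {'algorithm': ['PPO', 'A2C', 'DQN'], 'game': ['breakout', 'pong']}
configs = list(flat_product(grid))
print(len(configs))  # 6
print(configs[0])    # {'algorithm': 'PPO', 'game': 'breakout'}
```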

This is what I have so far:

def product(*args, repeat=1, root=False):
    pools = [tuple(pool) for pool in args] * repeat
    result = [[]]
    for pool in pools:
        result = [x+[y] for x in result for y in pool]
    print("************************")
    print(root)
    for r in result:
        print(tuple(r))
    print("************************")
    for prod in result:
        yield tuple(prod)


def recursive_cartesian_product(dic, root=True):
    # based on https://stackoverflow.com/a/50606871/11051330
    # added differentiation between list and entry to protect strings in dicts
    # with uneven depth
    keys, values = dic.keys(), dic.values()

    vals = (recursive_cartesian_product(v, False) if isinstance(v, dict)
            else v if isinstance(v, list) else (v,) for v in
            values)

    print("!", root)
    for conf in product(*vals, root=root):
        print(conf)
        yield dict(zip(keys, conf))

Here is the relevant output:

************************
True
('PPO', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'breakout'}}})
('PPO', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'pong'}}})
('A2C', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'breakout'}}})
('A2C', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'pong'}}})
('DQN', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'breakout'}}})
('DQN', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'pong'}}})
************************
('PPO', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'breakout'}}})
('PPO', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'pong'}}})
('A2C', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'pong'}}})
('A2C', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'pong'}}})
('DQN', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'pong'}}})
('DQN', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'pong'}}})

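Note how the print statements inside product work correctly, whereas the prints next to the yield go wrong: the env value is not updated for later configurations (from the third one on, the game is always 'pong').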

3 answers:

Answer 0 (score: 0)

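It turns out the problem was not inside the functions above but outside them. Each generated conf was passed to a function as **kwargs, which messed up the generator. **kwargs only copies the top level of the dict, and the yielded configurations share their nested dicts, so any change that function made to a nested value showed up in later configurations as well.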
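A minimal sketch of the mechanism, assuming the consumer mutates a nested dict. The generator below mirrors the question's version but uses itertools.product in place of the debug product; consume is a hypothetical stand-in for the function that received the config as **kwargs. Because product turns the inner generator into a tuple once, the same inner dict object appears in the PPO, A2C and DQN combinations, so overwriting a nested value while consuming one config changes what later configs contain.

```python
from itertools import product


def recursive_cartesian_product(dic):
    # Same structure as the question's generator, but without deepcopy.
    keys, values = dic.keys(), dic.values()
    vals = (recursive_cartesian_product(v) if isinstance(v, dict)
            else v if isinstance(v, list) else (v,) for v in values)
    for conf in product(*vals):
        yield dict(zip(keys, conf))


def consume(algorithm, env_config):
    # Hypothetical consumer: **kwargs gave it a new top-level dict,
    # but env_config is still the shared nested dict, and it overwrites a value in it.
    env_config['env_config']['AtariEnv']['game'] = 'pong'


d = {'algorithm': ['PPO', 'A2C', 'DQN'],
     'env_config': {'env': 'GymEnvWrapper-Atari',
                    'env_config': {'AtariEnv': {'game': ['breakout', 'pong']}}}}

seen = []
for conf in recursive_cartesian_product(d):
    # Record what each config contains at the moment it is yielded.
    seen.append((conf['algorithm'], conf['env_config']['env_config']['AtariEnv']['game']))
    consume(**conf)

print(seen)
# [('PPO', 'breakout'), ('PPO', 'pong'), ('A2C', 'pong'), ('A2C', 'pong'), ('DQN', 'pong'), ('DQN', 'pong')]
```

This reproduces the pattern in the question's output; yielding a deepcopy of each configuration gives every config its own nested dicts.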

Here is a quick fix:

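import itertools
from copy import deepcopy


def recursive_cartesian_product(dic):
    # based on https://stackoverflow.com/a/50606871/11051330
    # added differentiation between list and entry to protect strings
    # yield a deepcopy: otherwise the nested dicts are shared between configurations,
    # and a consumer that mutates one config corrupts the ones yielded later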
    keys, values = dic.keys(), dic.values()

    vals = (recursive_cartesian_product(v) if isinstance(v, dict)
            else v if isinstance(v, list) else (v,) for v in
            values)

    for conf in itertools.product(*vals):
        yield deepcopy(dict(zip(keys, conf)))

Answer 1 (score: 0)

itertools already has a product type:

from itertools import product


d = {'algorithm': ['PPO', 'A2C', 'DQN'],
     'env_config': {'env': 'GymEnvWrapper-Atari',
                    'env_config': {'AtariEnv': {'game': ['breakout', 'pong']}}}}

for algo, game in product(d['algorithm'],
                          d['env_config']['env_config']['AtariEnv']['game']):
    print((algo, {'env': 'GymEnvWrapper-Atari',
                  'env_config': {'AtariEnv': {'game': game}}}))

Answer 2 (score: 0)

Using itertools.product is indeed simpler than rolling your own.

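If you don't expect env_config to vary (apart from the game name), there is no need to implement a generic recursive dict visitor.
If all you want is the product of the algorithms and the game names, always using AtariEnv, then: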

from itertools import product

possible_configurations = {'algorithm': ['PPO', 'A2C', 'DQN'],
                           'env_config': {'env': 'GymEnvWrapper-Atari',
                                          'env_config': {'AtariEnv': {'game': ['breakout', 'pong']}}}}

algorithms = tuple(possible_configurations["algorithm"])
games = tuple(
    {"env": "GymEnvWrapper-Atari", "env_config": {"AtariEnv": {"game": game_name}}}
    for game_name in possible_configurations["env_config"]["env_config"]["AtariEnv"]["game"]
)

factors = (algorithms, games)
for config in product(*factors):
    print(config)

If you prefer a generic solution, here is mine:

from itertools import product

possible_configurations = {'algorithm': ['PPO', 'A2C', 'DQN'],
                           'env_config': {'env': 'GymEnvWrapper-Atari',
                                          'env_config': {'AtariEnv': {'game': ['breakout', 'pong']}}}}


def product_visitor(obj):
    if isinstance(obj, dict):
        yield from (
            dict(possible_product)
            for possible_product in product(
                *(
                    [(key, possible_value) for possible_value in product_visitor(value)]
                    for key, value in obj.items())))
    elif isinstance(obj, list):
        for value in obj:
            yield from product_visitor(value)
    else:  # either a string, a number, a boolean or null (all scalars)
        yield obj


configs = tuple(product_visitor(possible_configurations))
print("\n".join(map(str, configs)))
assert configs == (
    {'algorithm': 'PPO', 'env_config': {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'breakout'}}}},
    {'algorithm': 'PPO', 'env_config': {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'pong'}}}},
    {'algorithm': 'A2C', 'env_config': {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'breakout'}}}},
    {'algorithm': 'A2C', 'env_config': {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'pong'}}}},
    {'algorithm': 'DQN', 'env_config': {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'breakout'}}}},
    {'algorithm': 'DQN', 'env_config': {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'pong'}}}},
)
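One caveat: product_visitor shares nested dicts between configurations in the same way as the question's generator, because each inner list of (key, value) pairs is built once and product reuses its dicts in every combination. If the consumer may mutate configurations, a different technique avoids deepcopy altogether: flatten the nested dict into leaf paths, take the product over the leaf options, and rebuild a brand-new nested dict for each combination. A minimal sketch; leaf_choices, build and fresh_product are hypothetical names, and unlike product_visitor it treats list elements as opaque options (it does not recurse into dicts inside lists) and drops empty nested dicts.

```python
from itertools import product


def leaf_choices(obj, path=()):
    """Yield (path, options) for every leaf: lists are options, scalars are a single fixed option."""
    for key, value in obj.items():
        if isinstance(value, dict):
            yield from leaf_choices(value, path + (key,))
        elif isinstance(value, list):
            yield path + (key,), value
        else:
            yield path + (key,), [value]


def build(assignments):
    """Rebuild a new nested dict from (path, value) pairs; nothing is shared between calls."""
    root = {}
    for path, value in assignments:
        node = root
        for key in path[:-1]:
            node = node.setdefault(key, {})
        node[path[-1]] = value
    return root


def fresh_product(obj):
    """Yield every configuration as an independent nested dict."""
    leaves = list(leaf_choices(obj))
    paths = [path for path, _ in leaves]
    options = [opts for _, opts in leaves]
    # With no leaves, product() yields a single empty tuple, so {} is yielded once.
    for combo in product(*options):
        yield build(zip(paths, combo))


possible_configurations = {'algorithm': ['PPO', 'A2C', 'DQN'],
                           'env_config': {'env': 'GymEnvWrapper-Atari',
                                          'env_config': {'AtariEnv': {'game': ['breakout', 'pong']}}}}

configs = list(fresh_product(possible_configurations))
print(len(configs))  # 6
print(configs[0]['env_config'] is configs[2]['env_config'])  # False: PPO and A2C get separate dicts
```

The order matches the assert above, since product varies the last leaf (the game) fastest.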