本问题与“Cartesian product of nested dictionaries of lists”（嵌套列表字典的笛卡尔积）这个问题相关。
假设我有一个嵌套字典，其中用列表来表示多个可选配置，例如：
# Example input: each list holds the alternatives for its key; the original
# snippet was missing one closing brace, which made it a SyntaxError.
{'algorithm': ['PPO', 'A2C', 'DQN'],
 'env_config': {'env': 'GymEnvWrapper-Atari',
                'env_config': {'AtariEnv': {'game': ['breakout', 'pong']}}}}
目标是计算嵌套字典中各列表的笛卡尔积，以获得所有可能的配置。
这是我目前写出的代码：
def product(*args, repeat=1, root=False):
    """Generate the Cartesian product of ``args`` as tuples, like ``itertools.product``.

    Debug variant. On the first ``next()`` call the whole product is built in
    memory and printed between two rows of asterisks, preceded by the value
    of ``root``. Only after that are the tuples yielded one by one. ``root``
    is used solely for this printout.
    """
    # ``repeat`` duplicates the whole list of pools, mirroring itertools.product.
    pools = repeat * [tuple(source) for source in args]
    partial = [[]]
    for pool in pools:
        partial = [prefix + [item] for prefix in partial for item in pool]
    combos = [tuple(prefix) for prefix in partial]

    banner = "************************"
    print(banner)
    print(root)
    for combo in combos:
        print(combo)
    print(banner)

    for combo in combos:
        yield combo
def recursive_cartesian_product(dic, root=True):
    """Yield every configuration obtained by expanding the list values of ``dic``.

    Based on https://stackoverflow.com/a/50606871/11051330. Nested dicts are
    expanded recursively and each list supplies the alternatives for its key.
    Any other value, such as a string, is kept as one fixed choice. Strings
    are therefore never split into characters, even when dicts have uneven
    depth. ``root`` is only forwarded to ``product`` for its debug printout.

    Note: ``product`` turns every nested generator into a tuple once. The same
    nested dict object therefore ends up in several yielded configurations, so
    mutating one of them also changes the others.
    """
    def choices(value):
        # Alternatives contributed by a single value of ``dic``.
        if isinstance(value, dict):
            return recursive_cartesian_product(value, False)
        if isinstance(value, list):
            return value
        return (value,)

    keys = dic.keys()
    vals = (choices(v) for v in dic.values())
    print("!", root)
    for conf in product(*vals, root=root):
        print(conf)
        yield dict(zip(keys, conf))
这是相关的输出:
************************
True
('PPO', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'breakout'}}})
('PPO', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'pong'}}})
('A2C', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'breakout'}}})
('A2C', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'pong'}}})
('DQN', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'breakout'}}})
('DQN', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'pong'}}})
************************
('PPO', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'breakout'}}})
('PPO', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'pong'}}})
('A2C', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'pong'}}})
('A2C', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'pong'}}})
('DQN', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'pong'}}})
('DQN', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'pong'}}})
请注意：product 中的打印语句输出是正确的；而 recursive_cartesian_product 中位于 yield 之前的 print(conf) 输出却是错的。后面几个配置中与 env 相关的值没有按预期变化，例如 A2C 和 DQN 的 'game' 一直是 'pong'。
答案 0 :(得分:0)
原来问题不在上面的函数内部，而在函数外部：生成的 conf 以 **kwargs 的形式被传给了另一个函数，这弄乱了生成器的输出。原因在于 product 会把内层生成器一次性展开成元组，所以同一个嵌套字典对象会被多个配置共享，对其中一个配置的修改会同时影响其他配置。因此需要在 yield 时做一次深拷贝。
这是一个快速的解决方案:
def recursive_cartesian_product(dic):
    """Yield every configuration obtained by expanding the list values of ``dic``.

    Based on https://stackoverflow.com/a/50606871/11051330. Nested dicts are
    expanded recursively and each list contributes one alternative per
    element. Any other value, such as a string, is kept as one fixed choice,
    so strings are not split into characters.

    Each yielded dict is an independent deep copy. ``itertools.product``
    materializes its inputs and reuses the same nested objects across
    combinations. Without the copy, a caller that mutates one configuration
    would silently change the others.
    """
    # Imported locally: the snippet used both names without importing them.
    import itertools
    from copy import deepcopy

    def _expand(node):
        # Yields configurations that may share nested objects; the single
        # deepcopy below makes each one independent (copying at every level,
        # as before, was redundant).
        keys = list(node)
        choices = (
            _expand(value) if isinstance(value, dict)
            else value if isinstance(value, list)
            else (value,)
            for value in node.values()
        )
        for combo in itertools.product(*choices):
            yield dict(zip(keys, combo))

    for config in _expand(dic):
        yield deepcopy(config)
答案 1 :(得分:0)
itertools 中已经提供了现成的 product 函数：
from itertools import product
# Input: lists hold the alternatives to combine. (The original literal was
# missing its final closing brace, which made it a SyntaxError.)
d = {'algorithm': ['PPO', 'A2C', 'DQN'],
     'env_config': {'env': 'GymEnvWrapper-Atari',
                    'env_config': {'AtariEnv': {'game': ['breakout', 'pong']}}}}
# Read the fixed env name from ``d`` instead of repeating the literal, so the
# printed configurations cannot drift from the source dict.
env_name = d['env_config']['env']
for algo, game in product(d['algorithm'],
                          d['env_config']['env_config']['AtariEnv']['game']):
    print((algo, {'env': env_name,
                  'env_config': {'AtariEnv': {'game': game}}}))
答案 2 :(得分:0)
使用 itertools.product 确实比自己实现更简单。
如果除了游戏名称之外，env_config 的其他部分都不会变化，就没有必要实现一个通用的递归字典遍历器（visitor）。
也就是说，你只需要 algorithms 与 game 名称的笛卡尔积，并且始终使用 AtariEnv，那么可以这样写：
from itertools import product
possible_configurations = {'algorithm': ['PPO', 'A2C', 'DQN'],
                           'env_config': {'env': 'GymEnvWrapper-Atari',
                                          'env_config': {'AtariEnv': {'game': ['breakout', 'pong']}}}}
algorithms = tuple(possible_configurations["algorithm"])
# Take the fixed env name from the config rather than repeating the literal,
# so the generated configurations stay in sync with the source dict.
env_name = possible_configurations["env_config"]["env"]
games = tuple(
    {"env": env_name, "env_config": {"AtariEnv": {"game": game_name}}}
    for game_name in possible_configurations["env_config"]["env_config"]["AtariEnv"]["game"]
)
factors = (algorithms, games)
# NOTE: each dict in ``games`` is a single object that is reused for every
# algorithm; copy it before mutating a produced configuration.
for config in product(*factors):
    print(config)
如果你更倾向于通用的解决方案，可以参考我的实现：
from itertools import product
# Each list holds the alternatives to combine; other values are fixed.
possible_configurations = {
    'algorithm': ['PPO', 'A2C', 'DQN'],
    'env_config': {
        'env': 'GymEnvWrapper-Atari',
        'env_config': {
            'AtariEnv': {
                'game': ['breakout', 'pong'],
            },
        },
    },
}
def product_visitor(obj):
    """Yield every concrete configuration encoded by ``obj``.

    * ``dict``: the Cartesian product of the expansions of its values, rebuilt
      as dicts with the same keys in the same order. The last key varies
      fastest, as in ``itertools.product``.
    * ``list``: the expansions of each element in turn. The list enumerates
      alternatives, and nested lists are flattened.
    * anything else (string, number, boolean, None): a single fixed value.

    Each yielded value is an independent deep copy. ``itertools.product``
    materializes its inputs, so without the copy the same nested dict object
    would be shared by several configurations. Mutating one configuration
    would then silently change the others.
    """
    from copy import deepcopy

    def expand(node):
        # Shared-structure expansion; copying is done once, at the top level.
        if isinstance(node, dict):
            options_per_key = (
                [(key, option) for option in expand(value)]
                for key, value in node.items()
            )
            for pairs in product(*options_per_key):
                yield dict(pairs)
        elif isinstance(node, list):
            for item in node:
                yield from expand(item)
        else:  # either a string, a number, a boolean or null (all scalars)
            yield node

    for value in expand(obj):
        yield deepcopy(value)
configs = tuple(product_visitor(possible_configurations))
# One configuration per line (same output as joining them with newlines).
print(*configs, sep="\n")

# Expected order: the last key varies fastest, as in itertools.product.
expected_configs = tuple(
    {'algorithm': algorithm,
     'env_config': {'env': 'GymEnvWrapper-Atari',
                    'env_config': {'AtariEnv': {'game': game}}}}
    for algorithm in ('PPO', 'A2C', 'DQN')
    for game in ('breakout', 'pong')
)
assert configs == expected_configs