我正在尝试编写一个函数,允许我灵活地在字典中的参数子集上运行网格搜索。我想要完成的具体行为如下:
def my_grid_searching_function(fiducial_dict, **param_iterators):
for params in desired_iterator:
fiducial_dict.update(params)
# compute chi^2
# write new fiducial_dict values and associated chi^2 value to disk
我的具体目标是找出如何撰写desired_iterator
。
函数my_grid_searching_function
接受任意关键字参数子集,每个参数都将被解释为fiducial_dict
的参数。
这似乎是itertools.product
的任务,但我遇到了一个问题。在下面的实现中,我能够使用product
将嵌套循环有效地转换为输入迭代器的 values 到一个循环中:
from itertools import product
def my_failed_grid_searching_function(fiducial_dict, **param_iterators):
desired_iterator = product(*list(param_iterators.values()))
for params in desired_iterator:
print(params)
fiducial_dict = {'x': 0, 'y': 0, 'z': 9}
my_failed_grid_searching_function(fiducial_dict, x=[4, 5, 6], y=[1, 2])
(1, 4)
(1, 5)
(1, 6)
(2, 4)
(2, 5)
(2, 6)
当然,问题在于输入param_iterators
已经普及到普通字典中,因此在my_failed_grid_searching_function
的命名空间内我不知道值的顺序是什么。
任何人都可以提供有关我如何撰写desired_iterator
的提示,以便产生足够的信息来更新fiducial_dict
,如上所示?
答案 0 :(得分:1)
由于您使用了任意关键字参数,因此您可以获取param_iterators
字典的键以同步您的param产品的位置。或者,我建议使用sklearn包来执行grid search。
无论如何,试试这个解决方案:
from itertools import product
def my_grid_searching_function(fiducial_dict, **param_iterators):
keys = param_iterators.keys()
desired_iterator = list(product(*list(param_iterators.values())))
for i in range(len(desired_iterator)):
print("Epoch: ", i)
for loc in range(len(desired_iterator[i])):
print(keys[loc], desired_iterator[i][loc])
# update your fiducial_dict here
my_grid_searching_function({'x': 0, 'y': 0}, x=[1,2,3,4], y=[6,7,8])
输出:
('Epoch: ', 0)
('y', 6)
('x', 1)
('Epoch: ', 1)
('y', 6)
('x', 2)
('Epoch: ', 2)
('y', 6)
('x', 3)
('Epoch: ', 3)
('y', 6)
('x', 4)
('Epoch: ', 4)
('y', 7)
('x', 1)
('Epoch: ', 5)
('y', 7)
('x', 2)
('Epoch: ', 6)
('y', 7)
('x', 3)
('Epoch: ', 7)
('y', 7)
('x', 4)
('Epoch: ', 8)
('y', 8)
('x', 1)
('Epoch: ', 9)
('y', 8)
('x', 2)
('Epoch: ', 10)
('y', 8)
('x', 3)
('Epoch: ', 11)
('y', 8)
('x', 4)
***Repl Closed***
答案 1 :(得分:1)
感谢Scratch' N' Purr指出序列顺序可以简单地从.keys()方法确定。
from itertools import product
def param_grid_search_generator(**param_iterators):
param_names = list(param_iterators.keys())
param_combination_generator = product(*list(param_iterators.values()))
for param_combination in param_combination_generator:
yield {param_names[i]: param_combination[i] for i in range(len(param_names))}