我已经从scala转移到python来训练深度学习模型。我错过了数据操作功能方法的一些表现力,尤其是group by
。
我想将一组/文件列表(pathlib.Path对象)拆分为2组:验证和培训。我有一个函数which_set
,它为每个文件关联一个集合名称。在scala我会写:
>>> all_paths = TRAIN_PATH.glob('*')
>>> all_paths.groupby(which_set)
{'valid': [....], 'train': [....]}
在python中我挣扎。我可以使用pandas及其groupby
方法,但之后我必须将all_paths
转换为字符串。不理想。我可以使用itertools,但它为每个元素多次调用which_set:
from itertools import groupby
{k : list(l) for k, l in groupby(sorted(all_paths, key=which_set), key=which_set)}
这段代码似乎更加pythonic但它不起作用,而且很难阅读(与scala版本相比)
paths = {}
for x in all_paths: paths.setdefault(which_set(x), []).append(x)
你知道任何图书馆或方法可以帮助我吗?