我有一个令牌数组,每个令牌对应于从1
到n
的不同类。我需要 balance tokens
数组/列表,以便每个类有相等数量的标记。我想通过删除tokens
的元素来做到这一点。
在下面的示例中,令牌数量最少的类是class 2
,它只有2
个令牌。因此,我想从其他类中删除元素,直到它们的数量也为2
。
例如
tokens = array(['a','b','c','d','e','f','g','h','l'])
classes = array([ 1 , 1 , 1 , 1 , 2 , 2 , 3 , 3 , 3])
在此示例中,类以升序排列(为了清楚起见),但实际上,这些类没有特定的顺序。
例如
sol = array(['c','d','e','f','g','h'])
或
sol = array(['a','b','e','f','g','h'])
等
很明显,因为您可以选择要删除的多余元素,所以可以有不同的解决方案(如上)。我需要一个可以使用tokens
和classes
并输出sol
的函数。
答案 0 :(得分:2)
使用Counter
的解决方案:
tokens = ['a','b','c','d','e','f','g','h','l']
lst = [ 1 , 1 , 1 , 1 , 2 , 2 , 3 , 3 , 3]
from collections import Counter
c = Counter(lst)
min_cnt = min(c.values())
new_lst = list( zip(tokens, lst) )
while True:
tmp = []
should_break = True
for t, i in new_lst:
if c[i] > min_cnt:
c[i] -= 1
should_break = False
else:
tmp.append( (t, i) )
new_lst = tmp
if should_break:
break
print([t for t, _ in new_lst])
打印:
['c', 'd', 'e', 'f', 'h', 'l']
使用groupby
的其他可能解决方案:
tokens = ['a','b','c','d','e','f','g','h','l']
lst = [ 1 , 1 , 1 , 1 , 2 , 2 , 3 , 3 , 3]
from collections import Counter
from itertools import groupby, islice
c = Counter(lst)
min_cnt = min(c.values())
out = []
for v, g in groupby(sorted(enumerate(zip(tokens, lst)), key=lambda k: k[1][1]), lambda k: k[1][1]):
out.extend(islice(g, 0, min_cnt))
print( [val for _, (val, _) in sorted(out, key=lambda k: k[0])] )
打印:
['a', 'b', 'e', 'f', 'g', 'h']
答案 1 :(得分:1)
这是使用NumPy做到这一点的一种方法。这将始终选择每个类的第一个外观。
import numpy as np
def balance(tokens, classes):
# Count appearances of each class
c = np.bincount(classes - 1)
n = c.min()
# Accumulated counts for each class shifted one position
cs = np.roll(np.cumsum(c), 1)
cs[0] = 0
# Compute appearance index for each class
i = np.arange(len(classes)) - cs[classes - 1]
# Mask excessive appearances
m = i < n
# Return corresponding tokens
return tokens[m]
tokens = np.array(['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'l'])
classes = np.array([ 1, 1, 1, 1, 2, 2, 3, 3, 3])
print(balance(tokens, classes))
# ['a' 'b' 'e' 'f' 'g' 'h']
就目前而言,当某些类完全丢失时(因为最小出现次数为零,因此解决方案中不会出现类),该函数将返回一个空数组,但是您可以根据需要进行调整。
>答案 2 :(得分:1)
又一个简短的解决方案:
import random
from itertools import chain
from operator import itemgetter
import toolz
tokens = ['a','b','c','d','e','f','g','h','l']
classes = [ 1 , 1 , 1 , 1 , 2 , 2 , 3 , 3 , 3]
groups = toolz.groupby(itemgetter(1), zip(tokens, classes))
max_size = len(min(groups.values(), key=len))
random_samples = chain.from_iterable(map(lambda x: random.sample(x, k=max_size), list(groups.values())))
chosen_tokens, corresponding_classes = list(zip(*random_samples))
或完全使用buildins
个模块
import random
from itertools import chain, groupby, tee
from operator import itemgetter
tokens = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'l']
classes = [1, 1, 1, 1, 2, 2, 3, 3, 3]
groups_for_max_size, groups = tee(groupby(zip(tokens, classes), itemgetter(1)), 2)
max_size = len(min(groups_for_max_size, key = len))
random_samples = chain.from_iterable(map(lambda x: random.sample(list(x[1]), k = max_size), groups))
chosen_tokens, corresponding_classes = list(zip(*random_samples))
编辑:我认为还有一个更短的解决方案:
from itertools import chain, groupby
from operator import itemgetter
groups = (sorted(tokens, key=lambda x: random.random())
for _, tokens in groupby(zip(tokens, classes), itemgetter(1)))
chosen_tokens, corresponding_classes = zip(*chain.from_iterable(zip(*groups)))
只有两个步骤:1.确保每个组的列表都是随机的(这在sorted(tokens, key=lambda x: random.random())
中神奇地发生了,因为排序键始终是一个随机值)。
2.同样重要的是要知道zip
对元素进行采样,直到用尽最短的生成器为止(这使该解决方案变得如此之短)。 zip(*groups)
是一个迭代器,它在每次迭代中检索三元组(3个类)。由于我们事先对列表进行了混洗,因此对它们进行了随机采样。如果我们要再次分隔标记和类,则将三元组连接起来并再次解压缩。
答案 3 :(得分:1)
使用Counter
的另一种解决方案:
import random
from collections import Counter
tokens = np.array(['a','b','c','d','e','f','g','h','l'])
classes = np.array([ 1 , 1 , 1 , 1 , 2 , 2 , 3 , 3 , 3])
def sampling(tokens, classes):
dc = {}
sol = []
for i in range(len(classes)):
if classes[i] in dc:
dc[classes[i]].append(tokens[i])
else:
dc[classes[i]] = [tokens[i]]
sample_counts = Counter(classes)
min_sample = min(sample_counts.values())
for i in dc:
sol += (random.sample(dc[i],min_sample))
return sol
print(sampling(tokens, classes))
>>> ['d', 'a', 'f', 'e', 'g', 'h']