一种有效地用组替换数组元素的算法

时间:2020-01-16 13:59:54

标签: python algorithm

有N个元素,每个元素都有其自己的成本。有M个小组。每个组都包含数组中元素的几个索引,并且都有自己的成本。

例如输入

6
100 5
200 5
300 5
400 5
500 5
600 3
2
4 6
100 200 300 700
3 5
300 400 500

第一个数字N是元素数。接下来的N行包含特定项目的索引和成本。然后是数字M(组数)。之后是2 * M行。这些行包含组中元素的数量,组本身的成本以及元素的索引。

我想找到可以购买全部N件商品的最低费用。

在此示例中,最有利的是同时接受两组并分别购买编号为600的元素。答案是14。(6 + 5 + 3)

这是我的解决方法

from queue import PriorityQueue

N = int(input())
dct = {}
groups = PriorityQueue()

for i in range(N):
    a,c = [int(j) for j in input().split()]
    dct[a] = c 

M = int(input())


for i in range(M): 
    k,c = [int(j) for j in input().split()]
    s = 0
    tmp = []
    for j in input().split():
        j_=int(j)
        if j_ in dct:
            s+=dct[j_]
            tmp.append(j_)
    d = c-s
    if d<0:
        groups.put([d, c, tmp])  

s = 0
while not groups.empty():
    #print(dct)
    #for i in groups.queue:
    #    print(i)
    g = groups.get()
    if g[0]>0:
        break
    #print('G',g)
    #print('-------')
    for i in g[2]:
        if i in dct:
            del(dct[i])
    s += g[1]
    groups_ = PriorityQueue()
    for i in range(len(groups.queue)):
            g_ = groups.get()
            s_ = 0
            tmp_ = []
            for i in g_[2]:
                if i in dct:
                    s_+=dct[i]
                    tmp_.append(i)
            d = g_[1]-s_
            groups_.put([d, g_[1], tmp_])
    groups = groups_ 

for i in dct:
    s+=dct[i]

print(s)

但这不是完全正确的。

例如,对于这样的测试,它给出的答案为162。但是正确的答案是160。仅采用第一组和第二组并分别采用索引为0的元素是最有益的。

20
0 24
1 32
2 33
3 57
4 57
5 50
6 50
7 41
8 2
9 73
10 81
11 73
12 55
13 3
14 54
15 43
16 98
17 8
18 41
19 97
5
17 61
17 9 11 15 1 13 14 7 20 2 3 16 12 5 8 4 6
13 75
20 15 5 9 10 11 7 8 18 2 4 19 16
10 96
3 9 4 18 11 6 8 5 2 14
9 92
18 1 6 9 19 8 4 16 10
19 77
14 17 18 3 2 4 7 6 8 9 10 20 13 12 15 19 1 16 5

我也尝试过蛮力搜索,但是这样的解决方案太慢了

from itertools import combinations

N = int(input())
dct = {}

s = 0
for i in range(N):
    a,c = [int(j) for j in input().split()]
    dct[a] = c
    s += c
m = s

M = int(input())

groups = []
for i in range(M):
    k,c = [int(j) for j in input().split()]
    s = 0
    tmp = []
    for j in input().split():
        j_=int(j)
        if j_ in dct:
            s+=dct[j_]
            tmp.append(j_)
    groups.append( [c, tmp] )

for u in range(1,M+1):
    for i in list(combinations(groups, u)): 
        s = 0
        tmp = dct.copy()
        for j in i:
            s += j[0]
            for t in j[1]:
                if t in tmp:
                    del(tmp[t])
        for j in tmp:
            s += tmp[j] 
        #print(i,s)
        if s < m:
            m = s  
print(m)

我认为可以通过动态编程解决此问题。也许这是典型背包问题的某些变体。告诉我哪种算法更好用。

2 个答案:

答案 0 :(得分:2)

所谓的set cover problem(即NP-Hard)似乎是您的问题的特例。因此,恐怕没有有效的算法可以解决它。

答案 1 :(得分:1)

如上所述,这是一个难题,没有“有效”算法。

您可以将其视为图问题,其中图的节点都是组的所有可能组合(其中每个元素本身也是一个组)。当存在一个组 g 时,两个节点 u v 与有向边相连,从而使 u中的键并集 g 中的 v 中的键集对应。

然后在该图中执行Dijkstra搜索,从表示根本没有选择任何组的状态的节点开始(成本0,没有键)。此搜索将最大程度地降低成本,并且您可以使用额外的优化功能,即在同一路径中永远不会将 g 组视为两次。一旦访问了涵盖所有密钥的状态(节点),您就可以退出算法(这是Dijkstra算法的典型做法),因为这代表覆盖所有密钥的最低成本。

这样的算法仍然非常昂贵,因为每次在路径上增加一条边时,都必须计算密钥的并集。而且,...需要相当多的内存才能将所有状态保留在堆中。

这是一个潜在的实现方式:

from collections import namedtuple
import heapq

# Some named tuple types, to make the code more readable
Group = namedtuple("Group", "cost numtodo keys")
Node = namedtuple("Node", "cost numtodo keys nextgroupid")

def collectinput():
    inputNumbers = lambda: [int(j) for j in input().split()]

    groups = []
    keys = []

    N, = inputNumbers()
    for i in range(N):
        key, cost = inputNumbers()
        keys.append(key)
        # Consider these atomic keys also as groups (with one key)
        # The middle element of this tuple may seem superficial, but it improves sorting
        groups.append(Group(cost, N-1, [key]))
    keys = set(keys)

    M, = inputNumbers()
    for i in range(M):
        cost = inputNumbers()[-1]
        groupkeys = [key for key in inputNumbers() if key in keys]
        groups.append(Group(cost, N-len(groupkeys), groupkeys))

    return keys, groups


def solve(keys, groups):
    N = len(keys)
    groups.sort() # sort by cost, if equal, by number of keys left 

    # The starting node of the graph search
    heap = [Node(0, N, [], 0)]

    while len(heap):
        node = heapq.heappop(heap)
        if node.numtodo == 0:
            return node.cost
        for i in range(node.nextgroupid, len(groups)):
            group = groups[i]
            unionkeys = list(set(node.keys + group.keys))
            if len(unionkeys) > len(node.keys):
                heapq.heappush(heap, Node(node.cost + group.cost, N-len(unionkeys), unionkeys, i+1))

# Main
keys, groups = collectinput()
cost = solve(keys, groups)
print("solution: {}".format(cost))

这将为您发布的第二个问题输出160。