用于生成具有特定属性的唯一元组列表的算法

时间:2011-08-16 09:31:39

标签: python algorithm

我需要用唯一元组(A,B,C)标记N个对象,其中A< B< C和相同的As的最大数量是M.对于Bs和Cs各自相同。在所有解决方案中,搜索具有最低C值的一个。 (这最后一句话的意思是:如果两个解决方案中的一个具有最高C为4而另一个为5,那么第一个是正确答案。)

示例:

M = 1
N = 4
#          As Bs Cs
objects = [(1, 2, 3), 
           (2, 3, 4), 
           (3, 4, 5), 
           (4, 5, 6)]
M = 2
N = 4
objects = [(1, 2, 3),
           (1, 2, 4),
           (2, 3, 4),
           (2, 3, 5)]
# or e.g
objects = [(1, 2, 3), 
           (2, 3, 4), 
           (2, 4, 5), 
           (3, 4, 5)]

M = 3
N = 8
objects = [(1, 2, 3), 
           (2, 3, 4), 
           (2, 3, 5), 
           (2, 4, 5), 
           (3, 4, 5), 
           (3, 4, 6), 
           (3, 5, 6), 
           (4, 5, 6)]

我想出的程序是一个复杂的if else怪物:

import sys
# useage: labelme.py <N> <M>
class ObjectListTree(object):
    """Create many possible paths. 
    Store the parent in each node. 
    The last nodes are appended to the class wide endnodes.
    """
    endnodes = []
    def __init__(self, parent, label, counter, n, M, N):
        self.parent = parent
        self.M = M
        self.N = N
        self.label = label
        self.counter = counter
        self.n = n
        if n < N:    
            self.inc_a()
            self.inc_b() 
            self.inc_c()
        else:
            ObjectListTree.endnodes.append(self)

    def inc_a(self):
        if self.label[0]+1 < self.label[1]:
            if self.counter[1] < self.M:
                if self.counter[2] < self.M:
                    self.plus_1()
                else:
                    self.plus_1_3()
            else:
                if self.counter[2] < self.M:
                    self.plus_1_2()
                else:
                    self.plus_all()
        elif self.label[1]+1 < self.label[2]:
            if self.counter[2] < self.M:
                self.plus_1_2()
            else:
                self.plus_all()
        else:
            self.plus_all()

    def inc_b(self):
        if self.counter[0] == self.M:
            return
        if self.label[1]+1 < self.label[2] and self.counter[2] < self.M:
            self.plus_2()
        else:
            self.plus_2_3()

    def inc_c(self):
        if self.counter[0] == self.M or self.counter[1] == self.M:
            return
        else:
            self.plus_3()

    def plus_all(self):
        ObjectListTree(self, (self.label[0]+1, self.label[1]+1, self.label[2]+1),
                       counter = [1, 1, 1,],
                       n = self.n+1, N=self.N, M=self.M)
    def plus_1_2(self):
        ObjectListTree(self, (self.label[0]+1, self.label[1]+1, self.label[2]),
                       counter = [1, 1, self.counter[2]+1,],
                       n = self.n+1, N=self.N, M=self.M)
    def plus_1_3(self):
        ObjectListTree(self, (self.label[0]+1, self.label[1], self.label[2]+1),
                       counter = [1, self.counter[1]+1, 1,],
                       n = self.n+1, N=self.N, M=self.M)
    def plus_1(self):
        ObjectListTree(self, (self.label[0]+1, self.label[1], self.label[2]),
                       counter = [1, self.counter[1]+1, self.counter[2]+1,],
                       n = self.n+1, N=self.N, M=self.M)
    def plus_2(self):
        ObjectListTree(self, (self.label[0], self.label[1]+1, self.label[2]),
                       counter = [self.counter[0]+1, 1, self.counter[2]+1,],
                       n = self.n+1, N=self.N, M=self.M)
    def plus_2_3(self):
        ObjectListTree(self, (self.label[0], self.label[1]+1, self.label[2]+1),
                       counter = [self.counter[0]+1, 1, 1,],
                       n = self.n+1, N=self.N, M=self.M)
    def plus_3(self):
        ObjectListTree(self, (self.label[0], self.label[1], self.label[2]+1),
                       counter = [self.counter[0]+1, self.counter[1]+1, 1,],
                       n = self.n+1, N=self.N, M=self.M)

tree = ObjectListTree(parent=None, label=(1, 2, 3), counter = [1,1,1,], n=1, N=int(sys.argv[1]), M=int(sys.argv[2]))

best_path = tree.endnodes[0]
for n in tree.endnodes:
    if n.label[2] < best_path.label[2]:
        best_path = n
objects = []
while best_path:
    objects.append(best_path.label)
    best_path = best_path.parent
objects.reverse()
print objects 

但是我觉得这实际上应该是简单的东西,比如包含在集合中的itertools模块中的两个或三个函数的组合。任何人都可以看到一个简单的解决方案吗?

1 个答案:

答案 0 :(得分:3)

我认为这段代码符合您的要求,并始终以尽可能低的C生成解决方案。但是,并不完全使用itertools。

def generateTuples(N, M):
  done = 0
  counters = {}
  for C in range(3, N + 3):
    for B in range(2, C):
      for A in range(1, B):
        if (counters.get('A%i' % A, 0) < M and
            counters.get('B%i' % B, 0) < M and
            counters.get('C%i' % C, 0) < M):
          yield (A, B, C)
          counters['A%i' % A] = counters.get('A%i' % A, 0) + 1
          counters['B%i' % B] = counters.get('B%i' % B, 0) + 1
          counters['C%i' % C] = counters.get('C%i' % C, 0) + 1
          done += 1
          if done >= N:
            return

for (A, B, C) in generateTuples(8, 3):
  print (A, B, C)