求解x ^ 2 + y ^ 2 + z ^ 2 = N以得到x,y,z的所有唯一组合

时间:2017-12-01 10:14:51

标签: python algorithm math discrete-mathematics

问题是我必须找到所有可能的整数组合(x,y,z),当你得到一个整数N时,它将满足公式x ^ 2 + y ^ 2 + z ^ 2 = N.你必须找到所有独特的元组(x,y,z)。例如,如果元组之一是(1,2,1),则(2,1,1)不再是唯一的。

def find(n):

    ## max of x can be sqrt of n
    n1 = math.ceil(n ** (1/2))

    lst = []

    for i in range(1, n1):

        if (i ** 2) >= n:
           break

        for j in range(1, (i + 1)):
           if (i ** 2 + j ** 2) >= n:
               break
           if (i ** 2 + i ** 2 + i ** 2) < n:
               break

           for k in range(1, (j + 1)):
              if (i ** 2 + j ** 2 + j ** 2) < n:  
                 break
              if (i ** 2 + j ** 2 + k ** 2) > n:       
                 break
              if i ** 2 + j ** 2 + k ** 2 == n:
                 a = [i, j, k]
                 lst.append(a)
    return lst          

我试图做一些优化。例如,当不可能满足等式时,我将停止循环。但它仍然没有优化。我有一个8位整数的测试用例,例如12345678.我的代码需要很长时间来解决8位整数。

是否可以进行其他优化? 谢谢!

2 个答案:

答案 0 :(得分:3)

以下是@ vishal-wadhwa的一些小优化。

您可以检查除以8的任何平方给出0或4或1的残差。 使用这个你可以例如立即告诉我,如果n除以8有残差7就没有解决方案。

如果n是4的倍数,则只有偶数解,所以你可以除以4并求解。

最后,如果n除以4得到残差1,2或3,则解中必须有1,2或3个奇数,以及2个,1个或0个偶数。

下面的代码使用了所有这些来削减一些角落。

import math

def new_find(n):
    if n == 0:
        return [(0, 0, 0)]
    f = 1
    while n % 4 == 0:
        n //= 4
        f *= 2
    if f > 1:
        return [(f*a, f*b, f*c) for (a, b, c) in find(n)]
    if n % 8 == 7:
        return []

    offs = 0 if n % 4 == 1 else 1
    split = 3 if n % 4 == 3 else 2

    sol = []
    for i in range(offs, n, 2):
        if i*i > n // split:
            break
        for j in range(i, n, 2):
            if j*j > (n - i*i) / (split-1):
                break
            rem = n - i*i - j*j
            rs = int(math.sqrt(rem))
            if rem == rs*rs:
                sol.append((i, j, rs))
    return sol

def orig_find(n):
    lst = set()
    n1 = int(math.ceil(math.sqrt(n)))
    for i in range(0, n1):
        for j in range(i, n1):

            if i**2 + j**2 >= n:
                break

            tz = int(math.sqrt(n - i**2 - j**2))
            if tz ** 2 == n - i**2 - j**2:

                #This if-block makes the elements in tuple, sorted.
                #so that set can compare two values for equality
                #because (1, 2, 3) != (2, 3, 1)
                if tz < i:
                    i, j, tz = tz, i, j
                elif tz < j:
                    j, tz = tz, j

                a = (i, j, tz)
                lst.add(a) 
    return lst

def check_equal(a, b):
    import operator
    return all(map(operator.eq, sorted(map(sorted, a)), sorted(map(sorted, b))))

print(check_equal(new_find(12345678), orig_find(12345678)))

# True

答案 1 :(得分:2)

更新

这是最近的工作代码,它提供了唯一的值并使用了list

import math


def find(n):
    # max of x can be sqrt of n
    n1 = int(math.ceil(math.sqrt(n)))
    lst = list()
    for i in range(1, n1):
        for j in range(i + 1, n1):

            if i**2 + j**2 >= n:
                break

            tz = int(math.sqrt(n - i**2 - j**2))
            if tz ** 2 == n - i**2 - j**2:

                #This if-block makes the elements in tuple, sorted.
                #so that set can compare two values for equality
                #because (1, 2, 3) != (2, 3, 1)
                if tz < i:
                    i, j, tz = tz, i, j
                elif tz < j:
                    j, tz = tz, j
                lst.append((i, j, tz))
    lst.sort()
    j = 0
    i = 0
    while i < len(lst):
        lst[j] = lst[i]
        if lst[i] != lst[j]:
            i = i + 1
        while i < len(lst) and lst[i] == lst[j]:
            i = i + 1
        j = j + 1

    lst = lst[:j]
    return lst

print find(12345678)

<小时/> 您的解决方案绝对可以优化。不需要k迭代的循环。您可以使用简单的数学来摆脱第三个循环。

我的python技能有点生疏,但以下代码有效。

import math


def find(n):
    # max of x can be sqrt of n
    n1 = int(math.ceil(math.sqrt(n)))
    lst = []

    for i in range(1, n1):
        for j in range(i + 1, n1):

            if i**2 + j**2 >= n:
                break

            tz = int(math.sqrt(n - i**2 - j**2))
            if tz ** 2 == n - i**2 - j**2:
                a = [i, j, tz]
                lst.append(a)
    return lst

print find(12345678)

经过测试:here

对于唯一值,您可以使用不同的数据结构,例如维护唯一值的集合,或者使用另一个if-check来查看是​​否已经存在类似的条目。

lst = set()

for i in range(1, n1):
    for j in range(i + 1, n1):

        if i**2 + j**2 >= n:
            break

        tz = int(math.sqrt(n - i**2 - j**2))
        if tz ** 2 == n - i**2 - j**2:

            #This if-block makes the elements in tuple, sorted.
            #so that set can compare two values for equality
            #because (1, 2, 3) != (2, 3, 1)
            if tz < i:
                i, j, tz = tz, i, j
            elif tz < j:
                j, tz = tz, j

            a = (i, j, tz)
            lst.add(a)

经过测试:here

OR

如果您想继续使用该列表,只有当列表中没有当前的值集时,您才可以添加另一个检查并附加到列表中。

if tz < i:
    i, j, tz = tz, i, j
elif tz < j:
    j, tz = tz, j
a = [i, j, tz]
if not(a in lst):
    lst.append(a)

两者都是以增加时间复杂性为代价的。但是,使用set的IMO非常适合,因为它的平均大小写复杂度为O(1),而在list中搜索元素时,它的存在是O(n)