用Java优化Collat​​z猜想

时间:2016-09-12 13:46:42

标签: java performance optimization collatz

我正在研究一个程序,它确定使用Collat​​z猜想将数字变为1所需的步数(如果n为奇数,3n + 1;如果n为偶数,则为n / 2)。该程序在每次完成计算时将计算的数量增加一,并测试它可以以秒为单位计算的数量。这是我目前的工作计划:

public class Collatz {
    static long numSteps = 0;
    public static long calculate(long c){
        if(c == 1){
            return numSteps;
        }
        else if(c % 2 == 0){
            numSteps++;
            calculate(c / 2);
        }
        else if(c % 2 != 0){
            numSteps++;
            calculate(c * 3 + 1);
        }
        return numSteps;
    }
    public static void main(String args[]){
        int n = 1;
        long startTime = System.currentTimeMillis();
        while(System.currentTimeMillis() < startTime + 60000){

            calculate(n);
            n++;
            numSteps = 0;
        }
        System.out.println("The highest number was: " + n);
    }
}

目前它可以在一分钟内计算出大约1亿个数字,但我正在寻找有关如何进一步优化程序的建议,以便它可以在一分钟内计算出更多的数字。任何和所有建议将不胜感激:)。

1 个答案:

答案 0 :(得分:1)

你可以

  • 通过假设c % 2 == 0为假而c % 2 != 0必须为真来优化计算方法。您还可以假设c * 3 + 1必须是偶数,这样您就可以计算(c * 3 + 1)/2并将两个加到numSteps中。您可以使用循环而不是递归,因为Java没有尾调用优化。

  • 通过记忆获得更大的进步。对于每个数字,您可以记住您获得的结果,如果在返回该值之前计算了数字。您可能希望在记忆中设置上限,例如不高于您想要计算的最后一个数字。如果你不这样做,那么一些价值将是最大价值的许多倍。

为了您的兴趣

public class Collatz {
    static final int[] CALC_CACHE = new int[2_000_000_000];

    static int calculate(long n) {
        int numSteps = 0;
        long c = n;
        while (c != 1) {
            if (c < CALC_CACHE.length) {
                int steps = CALC_CACHE[(int) c];
                if (steps > 0) {
                    numSteps += steps;
                    break;
                }
            }
            if (c % 2 == 0) {
                numSteps++;
                c /= 2;
            } else {
                numSteps += 2;
                if (c > Long.MAX_VALUE / 3)
                    throw new IllegalStateException("c is too large " + c);
                c = (c * 3 + 1) / 2;
            }
        }
        if (n < CALC_CACHE.length) {
            CALC_CACHE[(int) n] = numSteps;
        }
        return numSteps;
    }

    public static void main(String args[]) {
        long n = 1, maxN = 0, maxSteps = 0;
        long startTime = System.currentTimeMillis();
        while (System.currentTimeMillis() < startTime + 60000) {
            for (int i = 0; i < 10; i++) {
                int steps = calculate(n);
                if (steps > maxSteps) {
                    maxSteps = steps;
                    maxN = n;
                }
                n++;
            }
            if (n % 10000000 == 1)
                System.out.printf("%,d%n", n);
        }
        System.out.printf("The highest number was: %,d, maxSteps: %,d for: %,d%n", n, maxSteps, maxN);
    }
}

打印

The highest number was: 1,672,915,631, maxSteps: 1,000 for: 1,412,987,847

更高级的答案是使用多个线程。在这种情况下,使用带有记忆的递归更容易实现。

import java.util.stream.LongStream;

public class Collatz {
    static final short[] CALC_CACHE = new short[Integer.MAX_VALUE-8];

    public static int calculate(long c) {
        if (c == 1) {
            return 0;
        }
        int steps;
        if (c < CALC_CACHE.length) {
            steps = CALC_CACHE[(int) c];
            if (steps > 0)
                return steps;
        }
        if (c % 2 == 0) {
            steps = calculate(c / 2) + 1;
        } else {
            steps = calculate((c * 3 + 1) / 2) + 2;
        }
        if (c < CALC_CACHE.length) {
            if (steps > Short.MAX_VALUE)
                throw new AssertionError();
            CALC_CACHE[(int) c] = (short) steps;
        }
        return steps;
    }

    static int calculate2(long n) {
        int numSteps = 0;
        long c = n;
        while (c != 1) {
            if (c < CALC_CACHE.length) {
                int steps = CALC_CACHE[(int) c];
                if (steps > 0) {
                    numSteps += steps;
                    break;
                }
            }
            if (c % 2 == 0) {
                numSteps++;
                c /= 2;
            } else {
                numSteps += 2;
                if (c > Long.MAX_VALUE / 3)
                    throw new IllegalStateException("c is too large " + c);
                c = (c * 3 + 1) / 2;
            }
        }
        if (n < CALC_CACHE.length) {
            CALC_CACHE[(int) n] = (short) numSteps;
        }
        return numSteps;
    }

    public static void main(String args[]) {
        long maxN = 0, maxSteps = 0;
        long startTime = System.currentTimeMillis();
        long[] res = LongStream.range(1, 6_000_000_000L).parallel().collect(
                () -> new long[2],
                (long[] arr, long n) -> {
                    int steps = calculate(n);
                    if (steps > arr[0]) {
                        arr[0] = steps;
                        arr[1] = n;
                    }
                },
                (a, b) -> {
                    if (a[0] < b[0]) {
                        a[0] = b[0];
                        a[1] = b[1];
                    }
                });
        maxN = res[1];
        maxSteps = res[0];
        long time = System.currentTimeMillis() - startTime;
        System.out.printf("After %.3f seconds, maxSteps: %,d for: %,d%n", time / 1e3, maxSteps, maxN);
    }
}

打印

After 52.461 seconds, maxSteps: 1,131 for: 4,890,328,815

注意:如果我将第二次计算调用更改为

     steps = calculate((c * 3 + 1) ) + 1;

打印

After 63.065 seconds, maxSteps: 1,131 for: 4,890,328,815