Java迭代合并排序运行时

时间:2017-11-07 01:02:09

标签: java arrays sorting merge mergesort

我正在为学校做一个项目,要求我为不同的排序算法编写代码。最困难的部分是在给定长度为2 ^ N的输入数组的情况下编写合并排序的迭代版本。我使用了一个名为merge的必需辅助方法来帮助迭代合并。

我的结构如下。给定一个2 ^ N的数组(让我们使用16的数组来解释我的方法),我遍历数组查看每个2个整数,并使用merge()交换一个大于另一个的整数。该过程将在长度为16的阵列中发生8次。然后我会遍历数组,查看每个4个整数,4次。我会使用我的merge方法合并每组4中的两个有序对。然后,我会看一个8个整数的块...等等。我的代码发布在这里:

public static void MergeSortNonRec(long[] a) {
    //======================
    //FILL IN YOUR CODE HERE
    //======================    
    /*
    System.out.print("Our array is: ");
    printArray(a);
    System.out.println('\n');
    */
    int alength = a.length;
    int counter = 2;
    //the counter will iterate through levels 2n - 2 4 8 16 32 etc.
    int pointtracker = 0;
    //the point tracker will keep track of the position in the array
    while (counter <= alength) {
        long [] aux = new long [alength];
        int low = pointtracker;
        int high = pointtracker + counter - 1;
        int mid = (low + high)/2;

        merge(a, aux, low, mid, high);

        if (high < alength - 1) {
            pointtracker += counter; 
            //move to the next block
        }
        else {
            //if our high point is at the end of the array
            counter *= 2;
            pointtracker = 0;
            //start over at a[0], with a doubled counter
        }
    }
    /*
    System.out.print("Final array is: ");
    printArray(a);
    System.out.println('\n');
    */
}//MergeSortNonRec()

我的合并方法如下:

    private static void merge(long[] a, long[] aux, int lo, int mid, int hi) {

    // copy to aux[]
    for (int k = lo; k <= hi; k++) {
        aux[k] = a[k]; 
    }

    // merge back to a[]
    int i = lo, j = mid+1;
    for (int k = lo; k <= hi; k++) {
        if      (i > mid)           a[k] = aux[j++];
        else if (j > hi)            a[k] = aux[i++];
        else if (aux[j] < aux[i])   a[k] = aux[j++];
        else                        a[k] = aux[i++];
    }
}

递归解决方案更加优雅:

    private static void sort(long[] a, long[] aux, int lo, int hi) {
    if (hi <= lo) return;
    int mid = lo + (hi - lo) / 2;
    sort(a, aux, lo, mid);
    sort(a, aux, mid + 1, hi);
    merge(a, aux, lo, mid, hi);
}

public static void MergeSort(long[] a) {
    long[] aux = new long[a.length];
    sort(a, aux, 0, a.length-1);
}

我的问题是运行时。我的教授说过,合并排序的迭代版本,因为我们只输入长度为2 ^ N的数组,应该比非迭代版本运行得更快。但是,我的迭代版本比大集合的递归版本运行得慢。以下是我的时间输出示例:

![runtime]:https://imgur.com/a/bzVuw“排序算法”

如何缩短迭代合并时间?

编辑:我已经弄清楚了。我在while循环之外移动了我的aux实例,这减少了指数级的时间。谢谢大家!

1 个答案:

答案 0 :(得分:0)

  

如何缩短迭代合并时间?

Wiki有一个迭代(自下而上)合并排序的简化示例:

https://en.wikipedia.org/wiki/Merge_sort#Bottom-up_implementation

为了减少时间,只需要一次性分配aux []数组,而不是在每次合并传递上复制数据,而是在每次传递时将引用交换到数组。

        long [] t = a;      // swap references
        a = aux;
        aux = t;

如果数组的大小是2的奇数幂,则需要复制数组一次或交换到位,而不是进行第一次合并传递。

  

迭代合并排序应该比递归合并排序运行得更快

假设两者的合理优化版本,迭代合并排序通常会更快,但相对差异会随着数组大小的增加而减少,因为大部分时间将花费在merge()函数中,这对于迭代和递归合并排序。

有折衷方案。递归版本将推送和弹出长度 - 2或2 *长度 - 2对索引到/从堆栈,而迭代生成索引(可以保存在寄存器中)。看起来在更深层次的递归过程中,递归版本更容易缓存,因为它在数组的一部分上运行,而迭代版本在每次传递时始终在整个数组中运行,但我从来没有看到这种情况导致递归合并排序的整体性能更好。 PC上的大多数缓存都是4种或更多种方式设置关联,因此两行用于输入,一行用于合并过程中的输出。在我的测试中,多线程迭代合并排序比单线程迭代合并排序快得多,因此我测试过的系统上的合并排序不是内存带宽限制。

这是迭代(自下而上)合并排序以及测试程序的一个稍微优化的示例:

package jsortbu;
import java.util.Random;

public class jsortbu {
    static void MergeSort(int[] a)          // entry function
    {
        if(a.length < 2)                    // if size < 2 return
            return;
        int[] b = new int[a.length];
        BottomUpMergeSort(a, b);
    }

    static void BottomUpMergeSort(int[] a, int[] b)
    {
    int n = a.length;
    int s = 1;                              // run size 
        if(1 == (GetPassCount(n)&1)){       // if odd number of passes
            for(s = 1; s < n; s += 2)       // swap in place for 1st pass
                if(a[s] < a[s-1]){
                    int t = a[s];
                    a[s] = a[s-1];
                    a[s-1] = t;
                }
            s = 2;
        }
        while(s < n){                       // while not done
            int ee = 0;                     // reset end index
            while(ee < n){                  // merge pairs of runs
                int ll = ee;                // ll = start of left  run
                int rr = ll+s;              // rr = start of right run
                if(rr >= n){                // if only left run
                    do                      //   copy it
                        b[ll] = a[ll];
                    while(++ll < n);
                    break;                  //   end of pass
                }
                ee = rr+s;                  // ee = end of right run
                if(ee > n)
                    ee = n;
                Merge(a, b, ll, rr, ee);
            }
            {                               // swap references
                int[] t = a;
                a = b;
                b = t;
            }
            s <<= 1;                        // double the run size
        }
    }

    static void Merge(int[] a, int[] b, int ll, int rr, int ee) {
        int o = ll;                         // b[]       index
        int l = ll;                         // a[] left  index
        int r = rr;                         // a[] right index
        while(true){                        // merge data
            if(a[l] <= a[r]){               // if a[l] <= a[r]
                b[o++] = a[l++];            //   copy a[l]
                if(l < rr)                  //   if not end of left run
                    continue;               //     continue (back to while)
                do                          //   else copy rest of right run
                    b[o++] = a[r++];
                while(r < ee);
                break;                      //     and return
            } else {                        // else a[l] > a[r]
                b[o++] = a[r++];            //   copy a[r]
                if(r < ee)                  //   if not end of right run
                    continue;               //     continue (back to while)
                do                          //   else copy rest of left run
                    b[o++] = a[l++];
                while(l < rr);
                break;                      //     and return
            }
        }
    }

    static int GetPassCount(int n)          // return # passes
    {
        int i = 0;
        for(int s = 1; s < n; s <<= 1)
            i += 1;
        return(i);
    }

    public static void main(String[] args) {
        int[] a = new int[10000000];
        Random r = new Random();
        for(int i = 0; i < a.length; i++)
            a[i] = r.nextInt();
        long bgn, end;
        bgn = System.currentTimeMillis();
        MergeSort(a);
        end = System.currentTimeMillis();
        for(int i = 1; i < a.length; i++){
            if(a[i-1] > a[i]){
                System.out.println("failed");
                break;
            }
        }
        System.out.println("milliseconds " + (end-bgn));
    }
}