O(n) worst-case time algorithm for 2D peak finding?

Date: 2014-04-16 21:15:37

Tags: arrays algorithm data-structures multidimensional-array language-agnostic

I was taking this course at MIT. In the first lecture, the professor posed the following problem:

A peak in a 2D array is a value whose 4 neighbors are all less than or equal to it. That is, for a[i][j] to be a local maximum,

a[i+1][j] <= a[i][j] 
&& a[i-1][j] <= a[i][j]
&& a[i][j+1] <= a[i][j]
&& a[i][j-1] <= a[i][j]
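
The definition translates directly into a bounds-checked test. This is a minimal sketch, assuming the array is a Python list of equal-length rows and that neighbors falling outside the array are simply ignored (a common convention that the problem statement does not spell out); `is_peak` is a hypothetical helper name.

```python
def is_peak(grid, i, j):
    """Return True if grid[i][j] is >= each of its (up to 4) in-bounds neighbors."""
    rows, cols = len(grid), len(grid[0])
    for di, dj in ((1, 0), (-1, 0), (0, 1), (0, -1)):
        ni, nj = i + di, j + dj
        # Neighbors outside the array are ignored (assumed border convention).
        if 0 <= ni < rows and 0 <= nj < cols and grid[ni][nj] > grid[i][j]:
            return False
    return True


grid = [[1, 2, 3],
        [4, 9, 5],
        [6, 7, 8]]
print(is_peak(grid, 1, 1))  # True: 9 is >= 2, 4, 5 and 7
print(is_peak(grid, 0, 0))  # False: the 4 below it is larger than 1
```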

Now, given an NxN 2D array, find a peak in the array.

This problem can easily be solved in O(N^2) time by iterating over all elements and returning a peak.
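
A brute-force sketch for comparison. It relies on the fact that the global maximum is always a peak (no neighbor can exceed it), so one full scan of all N^2 cells is enough; `find_peak_brute_force` is a hypothetical name.

```python
def find_peak_brute_force(grid):
    """O(N^2): return the (row, col) of the global maximum, which is always a peak."""
    best = (0, 0)
    for i, row in enumerate(grid):
        for j, v in enumerate(row):
            if v > grid[best[0]][best[1]]:
                best = (i, j)
    return best


print(find_peak_brute_force([[1, 2, 3], [4, 9, 5], [6, 7, 8]]))  # (1, 1)
```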

However, as described here, it can be optimized to O(N log N) time using a divide-and-conquer solution.
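
A sketch of column-splitting divide and conquer, assuming this is the kind of O(N log N) method the linked notes describe: take the middle column of the remaining range, find its maximum, and move toward a strictly larger left or right neighbor if there is one. Each step scans one column (N cells) and halves the number of columns, giving O(N log N) for an NxN array. The function name is hypothetical.

```python
def find_peak_column_split(grid):
    """O(rows * log(cols)): binary search over columns, returning (row, col) of a peak."""
    lo, hi = 0, len(grid[0]) - 1
    while True:
        mid = (lo + hi) // 2
        # Row index of the maximum in the middle column: O(rows).
        i = max(range(len(grid)), key=lambda r: grid[r][mid])
        v = grid[i][mid]
        if mid > lo and grid[i][mid - 1] > v:
            hi = mid - 1  # a peak exists among columns lo..mid-1
        elif mid < hi and grid[i][mid + 1] > v:
            lo = mid + 1  # a peak exists among columns mid+1..hi
        else:
            return (i, mid)


print(find_peak_column_split([[1, 2, 3], [4, 9, 5], [6, 7, 8]]))  # (1, 1)
```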

But they said that an O(N) time algorithm exists for this problem. Please suggest how we can solve it in O(N) time.

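PS (for those who know Python): The course staff have explained an approach here (Problem 1-5, Peak-Finding Proof) and provided some Python code in their problem set. But the explanation is far from obvious and hard to decipher, and the Python code is just as confusing. So I have copied the main part of the code below, for those who know Python and can tell from the code which algorithm is being used.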

def algorithm4(problem, bestSeen = None, rowSplit = True, trace = None):
    # if it's empty, we're done 
    if problem.numRow <= 0 or problem.numCol <= 0:
        return None

    subproblems = []
    divider = []

    if rowSplit:
        # the recursive subproblem will involve half the number of rows
        mid = problem.numRow // 2

        # information about the two subproblems
        (subStartR1, subNumR1) = (0, mid)
        (subStartR2, subNumR2) = (mid + 1, problem.numRow - (mid + 1))
        (subStartC, subNumC) = (0, problem.numCol)

        subproblems.append((subStartR1, subStartC, subNumR1, subNumC))
        subproblems.append((subStartR2, subStartC, subNumR2, subNumC))

        # get a list of all locations in the dividing row
        divider = crossProduct([mid], range(problem.numCol))
    else:
        # the recursive subproblem will involve half the number of columns
        mid = problem.numCol // 2

        # information about the two subproblems
        (subStartR, subNumR) = (0, problem.numRow)
        (subStartC1, subNumC1) = (0, mid)
        (subStartC2, subNumC2) = (mid + 1, problem.numCol - (mid + 1))

        subproblems.append((subStartR, subStartC1, subNumR, subNumC1))
        subproblems.append((subStartR, subStartC2, subNumR, subNumC2))

        # get a list of all locations in the dividing column
        divider = crossProduct(range(problem.numRow), [mid])

    # find the maximum in the dividing row or column
    bestLoc = problem.getMaximum(divider, trace)
    neighbor = problem.getBetterNeighbor(bestLoc, trace)

    # update the best we've seen so far based on this new maximum
    if bestSeen is None or problem.get(neighbor) > problem.get(bestSeen):
        bestSeen = neighbor
        if not trace is None: trace.setBestSeen(bestSeen)

    # return when we know we've found a peak
    if neighbor == bestLoc and problem.get(bestLoc) >= problem.get(bestSeen):
        if not trace is None: trace.foundPeak(bestLoc)
        return bestLoc

    # figure out which subproblem contains the largest number we've seen so
    # far, and recurse, alternating between splitting on rows and splitting
    # on columns
    sub = problem.getSubproblemContaining(subproblems, bestSeen)
    newBest = sub.getLocationInSelf(problem, bestSeen)
    if not trace is None: trace.setProblemDimensions(sub)
    result = algorithm4(sub, newBest, not rowSplit, trace)
    return problem.getLocationInSelf(sub, result)

#Helper Method
def crossProduct(list1, list2):
    """
    Returns all pairs with one item from the first list and one item from 
    the second list.  (Cartesian product of the two lists.)

    The code is equivalent to the following list comprehension:
        return [(a, b) for a in list1 for b in list2]
    but for easier reading and analysis, we have included more explicit code.
    """

    answer = []
    for a in list1:
        for b in list2:
            answer.append ((a, b))
    return answer
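
Reading `algorithm4`, it appears to work like this (my interpretation; the `problem` helpers such as `getMaximum`, `getBetterNeighbor` and `getSubproblemContaining` are not shown). It alternately takes the middle row or middle column of the current window as a divider, finds the maximum on it, looks at that cell's largest neighbor, and remembers the largest value seen so far as `bestSeen`. If the divider maximum has no larger neighbor, it is a peak; otherwise the search continues in the half that contains `bestSeen`. Keeping `bestSeen` maintains an invariant: it lies inside the window and is at least as large as every cell just outside it, so climbing uphill from it can never leave the window, and a peak always remains inside. Each round scans one dividing line and both dimensions halve every two rounds, so the total work is O(N). Below is a self-contained sketch of that reading on a plain list of lists; `find_peak_best_seen` and its inner helpers are hypothetical names. One assumed difference: this sketch checks neighbors in the whole array, so it can return as soon as the divider maximum has no larger neighbor, while the original appears to inspect only the current subproblem, which is presumably why it also compares against `bestSeen` before returning.

```python
def find_peak_best_seen(grid):
    """Return (row, col) of a peak, scanning alternating dividing rows/columns."""
    rows, cols = len(grid), len(grid[0])

    def val(p):
        return grid[p[0]][p[1]]

    def better_neighbor(p):
        # p itself, or its largest strictly-greater in-bounds neighbor
        # (looks at the whole array, not only the current window).
        i, j = p
        best = p
        for q in ((i + 1, j), (i - 1, j), (i, j + 1), (i, j - 1)):
            if 0 <= q[0] < rows and 0 <= q[1] < cols and val(q) > val(best):
                best = q
        return best

    r0, r1, c0, c1 = 0, rows - 1, 0, cols - 1  # current window (inclusive bounds)
    best_seen = None
    split_rows = True
    while True:
        if split_rows:
            mid = (r0 + r1) // 2
            divider = [(mid, j) for j in range(c0, c1 + 1)]
        else:
            mid = (c0 + c1) // 2
            divider = [(i, mid) for i in range(r0, r1 + 1)]
        loc = max(divider, key=val)  # maximum on the dividing line
        nb = better_neighbor(loc)
        if nb == loc:
            return loc  # no neighbor is larger: loc is a peak
        if best_seen is None or val(nb) > val(best_seen):
            best_seen = nb
        # best_seen is strictly larger than every divider cell, so it lies in
        # exactly one half; keep that half and switch the split direction.
        if split_rows:
            r0, r1 = (r0, mid - 1) if best_seen[0] < mid else (mid + 1, r1)
        else:
            c0, c1 = (c0, mid - 1) if best_seen[1] < mid else (mid + 1, c1)
        split_rows = not split_rows


print(find_peak_best_seen([[1, 2, 3], [0, 1, 0], [4, 5, 9]]))  # (2, 2)
```

On that 3x3 example the first round splits on row 1 (maximum 1, larger neighbor 5 below), keeps the bottom row, then splits on column 1 (5, larger neighbor 9), and finally returns the 9 at (2, 2).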

3 Answers:

Answer 0 (score: 7)

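  1. Assume the array is wider than it is tall; otherwise we split in the other direction.
  2. Split the array into three parts: the central column, the left side, and the right side.
  3. Scan the central column and its two neighboring columns and find the maximum.
    • If it is in the central column, that is our peak.
    • If it is on the left side, run this algorithm on the subarray left_side + central_column.
    • If it is on the right side, run this algorithm on the subarray right_side + central_column.
  4. Why this works:

    If the maximum element is in the central column, this is immediate. If it is not, we can step from that maximum to ever-larger neighbors and will never cross the central column (every cell there is smaller than the maximum), so a peak must exist in the corresponding half.

    Why this is O(n):

    Step #3 takes at most max_dimension iterations, and max_dimension is at least halved every two algorithm steps. This gives n + n/2 + n/4 + ... (each term counted at most twice), which is O(n). An important detail: we split along the larger dimension. For a square array, this means the split direction alternates. This is the difference from the last attempt in the PDF you linked to.

    Note: I'm not sure whether this exactly matches the algorithm in the code you provided; it may or may not be a different approach.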
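
The O(n) claim is a geometric series. A small sketch that tallies an upper bound on the cells scanned on dividing lines, assuming splits alternate on a square array and each split keeps at most half of the split dimension (rounded down, as when the divider itself is excluded). This answer's version scans three lines per step and keeps the central column, which changes only the constant. `divider_work_upper_bound` is a hypothetical name.

```python
def divider_work_upper_bound(rows, cols, split_rows=True):
    """Upper bound on cells scanned on dividing lines when splits alternate.

    A row split scans one row (cols cells) and keeps at most rows // 2 rows;
    a column split scans one column (rows cells) and keeps at most cols // 2 columns.
    """
    total = 0
    while rows > 0 and cols > 0:
        if split_rows:
            total += cols
            rows //= 2
        else:
            total += rows
            cols //= 2
        split_rows = not split_rows
    return total


print(divider_work_upper_bound(1024, 1024))  # 3070, just under 3 * 1024
```

The sum is n + 2(n/2 + n/4 + ...), so it stays below 3n for an n x n array.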

Answer 1 (score: 1)

Here is working Java code implementing @maxim1000's algorithm. The code below finds a peak in a 2D array in linear time.

import java.util.*;

class Ideone{
    public static void main (String[] args) throws java.lang.Exception{
        new Ideone().run();
    }
    int N , M ;

    void run(){
        N = 1000;
        M = 100;

        // arr is a random NxM array
        int[][] arr = randomArray();
        long start = System.currentTimeMillis();
//      for(int i=0; i<N; i++){   // TO print the array. 
//          System. out.println(Arrays.toString(arr[i]));
//      }
        System.out.println(findPeakLinearTime(arr));
        long end = System.currentTimeMillis();
        System.out.println("time taken : " + (end-start));
    }

    int findPeakLinearTime(int[][] arr){
        int rows = arr.length;
        int cols = arr[0].length;
        return kthLinearColumn(arr, 0, cols-1, 0, rows-1);
    }

    // helper function that splits on the middle Column
    int kthLinearColumn(int[][] arr, int loCol, int hiCol, int loRow, int hiRow){
        if(loCol==hiCol){
            int max = arr[loRow][loCol];
            int foundRow = loRow;
            for(int row = loRow; row<=hiRow; row++){
                if(max < arr[row][loCol]){
                    max = arr[row][loCol];
                    foundRow = row;
                }
            }
            if(!correctPeak(arr, foundRow, loCol)){
                System.out.println("THIS PEAK IS WRONG");
            }
            return max;
        }
        int midCol = (loCol+hiCol)/2;
        int max = arr[loRow][midCol];  // start from a cell in the middle column
        for(int row=loRow; row<=hiRow; row++){
            max = Math.max(max, arr[row][midCol]);
        }
        boolean centralMax = true;
        boolean rightMax = false;
        boolean leftMax  = false;

        if(midCol-1 >= 0){
            for(int row = loRow; row<=hiRow; row++){
                if(arr[row][midCol-1] > max){
                    max = arr[row][midCol-1];
                    centralMax = false;
                    leftMax = true;
                }
            }
        }

        if(midCol+1 < M){
            for(int row=loRow; row<=hiRow; row++){
                if(arr[row][midCol+1] > max){
                    max = arr[row][midCol+1];
                    centralMax = false;
                    leftMax = false;
                    rightMax = true;
                }
            }
        }

        if(centralMax) return max;
        if(rightMax)  return kthLinearRow(arr, midCol+1, hiCol, loRow, hiRow);
        if(leftMax)   return kthLinearRow(arr, loCol, midCol-1, loRow, hiRow);
        throw new RuntimeException("INCORRECT CODE");
    }

    // helper function that splits on the middle row
    int kthLinearRow(int[][] arr, int loCol, int hiCol, int loRow, int hiRow){
        if(loRow==hiRow){
            int ans = arr[loRow][loCol];
            int foundCol = loCol;
            for(int col=loCol; col<=hiCol; col++){
                if(arr[loRow][col] > ans){
                    ans = arr[loRow][col];
                    foundCol = col;
                }
            }
            if(!correctPeak(arr, loRow, foundCol)){
                System.out.println("THIS PEAK IS WRONG");
            }
            return ans;
        }
        boolean centralMax = true;
        boolean upperMax = false;
        boolean lowerMax = false;

        int midRow = (loRow+hiRow)/2;
        int max = arr[midRow][loCol];

        for(int col=loCol; col<=hiCol; col++){
            max = Math.max(max, arr[midRow][col]);
        }

        if(midRow-1>=0){
            for(int col=loCol; col<=hiCol; col++){
                if(arr[midRow-1][col] > max){
                    max = arr[midRow-1][col];
                    upperMax = true;
                    centralMax = false;
                }
            }
        }

        if(midRow+1<N){
            for(int col=loCol; col<=hiCol; col++){
                if(arr[midRow+1][col] > max){
                    max = arr[midRow+1][col];
                    lowerMax = true;
                    centralMax = false;
                    upperMax   = false;
                }
            }
        }

        if(centralMax) return max;
        if(lowerMax)   return kthLinearColumn(arr, loCol, hiCol, midRow+1, hiRow);
        if(upperMax)   return kthLinearColumn(arr, loCol, hiCol, loRow, midRow-1);
        throw new RuntimeException("Incorrect code");
    }

    int[][] randomArray(){
        int[][] arr = new int[N][M];
        for(int i=0; i<N; i++)
            for(int j=0; j<M; j++)
                arr[i][j] = (int)(Math.random()*1000000000);
        return arr;
    }

    boolean correctPeak(int[][] arr, int row, int col){//Function that checks if arr[row][col] is a peak or not
        if(row-1>=0 && arr[row-1][col]>arr[row][col])  return false;
        if(row+1<N && arr[row+1][col]>arr[row][col])   return false;
        if(col-1>=0 && arr[row][col-1]>arr[row][col])  return false;
        if(col+1<M && arr[row][col+1]>arr[row][col])   return false;
        return true;
    }
}

Answer 2 (score: 0)

To see why it is Θ(n):

[Figure: calculation steps]

Here is how the algorithm works:

1) Start from either 1a) or 1b).

1a) Set up the left half, the divider, and the right half.

1b) Set up the top half, the divider, and the bottom half.

2) Find the global maximum on the divider. [Θ(n)]

3) Look up the values of its neighbors, and record the largest node visited so far as the bestSeen node. [Θ(1)]

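4) Check whether the global maximum is GREATER THAN (or equal to) both bestSeen and all of its neighbors. [Θ(1)]

// Step 4 is the key to why this algorithm works.

(Analogy: Mumbai AND Pune are big cities in Maharashtra, but Mumbai is GREATER THAN Pune.)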

5) If 4) is true, return the global maximum as the 2D peak.

Otherwise, if this round used 1a), keep the half that contains bestSeen and go back to step 1b).

Otherwise, keep the half that contains bestSeen and go back to step 1a).
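
Step 4 relies on an invariant that the steps leave implicit, at least as I read them: bestSeen always lies inside the current window and is at least as large as every cell just outside it, so walking uphill from bestSeen can never leave the window and a peak must remain inside. A hypothetical checker for that invariant, with the window given as inclusive bounds (r0, r1, c0, c1); `invariant_holds` and the example grid are illustrative only.

```python
def invariant_holds(grid, window, best_seen):
    """Check: best_seen lies in window and is >= every in-bounds cell just outside it."""
    rows, cols = len(grid), len(grid[0])
    r0, r1, c0, c1 = window
    bi, bj = best_seen
    if not (r0 <= bi <= r1 and c0 <= bj <= c1):
        return False
    # Cells directly above/below the window's top/bottom edges and
    # directly left/right of its side edges.
    border = ([(r0 - 1, j) for j in range(c0, c1 + 1)] +
              [(r1 + 1, j) for j in range(c0, c1 + 1)] +
              [(i, c0 - 1) for i in range(r0, r1 + 1)] +
              [(i, c1 + 1) for i in range(r0, r1 + 1)])
    return all(grid[i][j] <= grid[bi][bj]
               for i, j in border if 0 <= i < rows and 0 <= j < cols)


g = [[1, 2, 3],
     [0, 1, 0],
     [4, 5, 9]]
print(invariant_holds(g, (2, 2, 0, 2), (2, 1)))  # True: 5 >= 0, 1, 0 in the row above
print(invariant_holds(g, (1, 1, 0, 2), (1, 1)))  # False: the 2 above is larger than 1
```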


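Intuitively, the algorithm always holds on to the side containing the largest value seen so far and keeps shrinking the window around it, until it finally settles on the bestSeen value as a peak.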

# Visual simulation

[Figures: rounds 1 to 6 and the final state, on a 10×10 matrix]

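For this 10×10 matrix, the search found a 2D peak in only 6 rounds. That is consistent with Θ(n) behavior, although one example does not prove the bound by itself; the bound follows from the dividing lines shrinking geometrically (n + n/2 + n/2 + n/4 + ...).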


猎鹰人