给定一组间隔,找出包含点的间隔数

时间:2014-04-14 14:20:01

标签: java algorithm

假设您有一组N个间隔(表示为左右坐标)和M个点。对于每个点,P算法应该找到P所属的区间数。

这是我的算法:

1)将间隔的左右坐标分别放在“左”和“右”数组中

2)排序“左”,与“右”同时交换条目

3)给定点P,找到最大i,使得左[i] <= P

4)对于每个j&lt; = i,如果右[j]&gt; = P

,则将结果加1

5)返回结果

Java实现:

import java.util.*;

class Intervals {

    public  static int count(int[] left, int[] right, int point) {
        int k = find(left, point), result = 0;
        for (int i=0; i < k; i++)
            if (point <= right[i]) result++;
        return result;
    }


    private static int find(int[] a, int point) {
        if (point < a[0]) return -1;
        int i = 0;
        while (i < a.length && a[i] <= point) i++;
        return i;
    }

    private static void sort(int[] a, int[] b) {
        sort(a, b, 0, a.length-1);
    }

    private static void sort(int[] left, int[] right, int lo, int hi) {
        if (hi <= lo) return;
        int lt = lo, gt = hi;
        exchange(left, right, lo, lo + (int) (Math.random() * (hi-lo+1)));
        int v = left[lo];
        int i = lo;
        while (i <= gt) {
            if      (left[i] < v) exchange(left, right, lt++, i++);
            else if (left[i] > v) exchange(left, right, i, gt--);
            else                  i++;
        }
        sort(left, right, lo, lt-1);
        sort(left, right, gt+1, hi);
    }

    private static void exchange(int[] left, int[] right, int i, int j) {
        int temp  = left[i];
        left[i]   = left[j];
        left[j]   = temp;
        temp      = right[i];
        right[i]  = right[j];
        right[j]  = temp;
    }

    private static boolean less(int[] a, int i, int j) {
        return a[i] < a[j];
    }


    public static void main(String[] args) {
        int n       = Integer.parseInt(args[0]);
        int m       = Integer.parseInt(args[1]);
        int[] left  = new int[n];
        int[] right = new int[n];
        Random r    = new Random();
        int MAX     = 100000;
        for (int i = 0; i < n; i++) {
            left[i] = r.nextInt(MAX);
            right[i] = left[i] + r.nextInt(MAX/4);
        }
        sort(left, right);
        for (int i=0; i < m; i++)
            System.out.println(count(left, right, r.nextInt(MAX)));
    }
}

此代码尚未通过某些测试,我正在尝试查找错误。重点是我实际上不知道在这些测试中使用了什么输入数据。

感谢。

3 个答案:

答案 0 :(得分:1)

可能不是您正在寻找的答案,但可能是另一天遇到此问题的人的答案。

如果您计划经常查询一组相当静态的范围,那么您可能希望考虑Interval Tree

public class IntervalTree<T extends IntervalTree.Interval> {
  // My intervals.

  private final List<T> intervals;
  // My center value. All my intervals contain this center.
  private final long center;
  // My interval range.
  private final long lBound;
  private final long uBound;
  // My left tree. All intervals that end below my center.
  private final IntervalTree<T> left;
  // My right tree. All intervals that start above my center.
  private final IntervalTree<T> right;

  public IntervalTree(List<T> intervals) {
    if (intervals == null) {
      throw new NullPointerException();
    }

    // Initially, my root contains all intervals.
    this.intervals = intervals;

    // Find my center.
    center = findCenter();

    /*
     * Builds lefts out of all intervals that end below my center.
     * Builds rights out of all intervals that start above my center.
     * What remains contains all the intervals that contain my center.
     */

    // Lefts contains all intervals that end below my center point.
    final List<T> lefts = new ArrayList<T>();
    // Rights contains all intervals that start above my center point.
    final List<T> rights = new ArrayList<T>();

    long uB = Long.MIN_VALUE;
    long lB = Long.MAX_VALUE;
    for (T i : intervals) {
      long start = i.getStart();
      long end = i.getEnd();
      if (end < center) {
        lefts.add(i);
      } else if (start > center) {
        rights.add(i);
      } else {
        // One of mine.
        lB = Math.min(lB, start);
        uB = Math.max(uB, end);
      }
    }

    // Remove all those not mine.
    intervals.removeAll(lefts);
    intervals.removeAll(rights);
    uBound = uB;
    lBound = lB;

    // Build the subtrees.
    left = lefts.size() > 0 ? new IntervalTree<T>(lefts) : null;
    right = rights.size() > 0 ? new IntervalTree<T>(rights) : null;

    // Build my ascending and descending arrays.
    /**
     * @todo Build my ascending and descending arrays.
     */
  }

  /*
   * Returns a list of all intervals containing the point.
   */
  List<T> query(long point) {
    // Check my range.
    if (point >= lBound) {
      if (point <= uBound) {
        // In my range but remember, there may also be contributors from left or right.
        List<T> found = new ArrayList<T>();
        // Gather all intersecting ones.
        // Could be made faster (perhaps) by holding two sorted lists by start and end.
        for (T i : intervals) {
          if (i.getStart() <= point && point <= i.getEnd()) {
            found.add(i);
          }
        }

        // Gather others.
        if (point < center && left != null) {
          found.addAll(left.query(point));
        }
        if (point > center && right != null) {
          found.addAll(right.query(point));
        }

        return found;
      } else {
        // To right.
        return right != null ? right.query(point) : Collections.<T>emptyList();
      }
    } else {
      // To left.
      return left != null ? left.query(point) : Collections.<T>emptyList();
    }

  }

  private long findCenter() {
    //return average();
    return median();
  }

  /**
   * @deprecated Causes obscure issues.
   * @return long
   */
  @Deprecated
  protected long average() {
    // Can leave strange (empty) nodes because the average could be in a gap but much quicker.
    // Don't use.
    long value = 0;
    for (T i : intervals) {
      value += i.getStart();
      value += i.getEnd();
    }
    return intervals.size() > 0 ? value / (intervals.size() * 2) : 0;
  }

  protected long median() {
    // Choose the median of all centers. Could choose just ends etc or anything.
    long[] points = new long[intervals.size()];
    int x = 0;
    for (T i : intervals) {
      // Take the mid point.
      points[x++] = (i.getStart() + i.getEnd()) / 2;
    }
    Arrays.sort(points);
    return points[points.length / 2];
  }

  void dump() {
    dump(0);
  }

  private void dump(int level) {
    LogFile log = LogFile.getLog();
    if (left != null) {
      left.dump(level + 1);
    }
    String indent = "|" + StringUtils.spaces(level);
    log.finer(indent + "Bounds:- {" + lBound + "," + uBound + "}");
    for (int i = 0; i < intervals.size(); i++) {
      log.finer(indent + "- " + intervals.get(i));
    }
    if (right != null) {
      right.dump(level + 1);
    }

  }

  /*
   * What an interval looks like.
   */
  public interface Interval {

    public long getStart();

    public long getEnd();
  }

  /*
   * A simple implemementation of an interval.
   */
  public static class SimpleInterval implements Interval {

    private final long start;
    private final long end;

    public SimpleInterval(long start, long end) {
      this.start = start;
      this.end = end;
    }

    public long getStart() {
      return start;
    }

    public long getEnd() {
      return end;
    }

    @Override
    public String toString() {
      return "{" + start + "," + end + "}";
    }
  }

  /**
   * Not called by App, so you will have to call this directly.
   * 
   * @param args 
   */
  public static void main(String[] args) {
    /**
     * @todo Needs MUCH more rigorous testing.
     */
    // Test data.
    long[][] data = {
      {1, 2},
      {2, 9},
      {4, 8},
      {3, 5},
      {7, 9},};
    List<Interval> intervals = new ArrayList<Interval>();
    for (long[] pair : data) {
      intervals.add(new SimpleInterval(pair[0], pair[1]));
    }
    // Build it.
    IntervalTree<Interval> test = new IntervalTree<Interval>(intervals);

    // Test it.
    System.out.println("Normal test: ---");
    for (long i = 0; i < 10; i++) {
      List<Interval> intersects = test.query(i);
      System.out.println("Point " + i + " intersects:");
      for (Interval t : intersects) {
        System.out.println(t.toString());
      }
    }

    // Check for empty list.
    intervals.clear();
    test = new IntervalTree<Interval>(intervals);
    // Test it.
    System.out.println("Empty test: ---");
    for (long i = 0; i < 10; i++) {
      List<Interval> intersects = test.query(i);
      System.out.println("Point " + i + " intersects:");
      for (Interval t : intersects) {
        System.out.println(t.toString());
      }
    }

  }
}

答案 1 :(得分:0)

您的find()方法不是返回一个索引而不是您想要的索引吗?

您返回导致循环退出的i。因此,i == a.lengtha[i] > point

您应该返回i-1。此外,这将是足够通用的,你不必处理特殊情况a[0] > point,顺便打破一个空数组。

private static int find(int[] a, int point) {
    int i = 0;
    while (i < a.length && a[i] <= point) i++;
    return i-1;
}

如果你真的想要进一步返回一个索引,那么你不应该在特殊情况下返回-1,而是0。这也删除了额外的代码行:

private static int find(int[] a, int point) {
    int i = 0;
    while (i < a.length && a[i] <= point) i++;
    return i;
}

答案 2 :(得分:0)

如果对一组间隔端点(左右端点)进行排序,然后从左到右处理间隔端点(跟踪它们所属的间隔),那么在每对连续端点之间可以记录在两个端点之间重叠该子区间的间隔数(每次遇到左端点时计数增加+1,每次遇到右端点时减少计数-1)。然后,给定一个查询点,您只需对端点数组进行二进制搜索,以找到其子区间包含查询点的两个端点,并报告先前计算的包含子区间的间隔数。对于N个区间和P个查询点,总运行时间为O(N log N + P log N)。存储是O(N)。