使用RangeMap时相交范围

时间:2017-01-01 19:45:38

标签: java guava

我遇到了像

这样的问题

您正在维护对冲基金的交易平台。对冲基金的交易员全天执行交易策略。 为简单起见,我们假设每个交易策略只要运行时就会i磅/分钟。 i可能是否定的。 在一天结束时,您有一个如下所示的日志文件:

  • timestamp_start_1,timestamp_end_1,i_1
  • timestamp_start_2,timestamp_end_2,i_2
  • timestamp_start_3,timestamp_end_3,i_3

每一行代表策略何时开始执行,何时停止,以及收益产生的速率。 写一些代码来返回对冲基金每分钟赚取最高金额的那一天的时间。

示例:

输入:

  • (1,13,400)
  • (10,20,100)

结果:

  • (10,13)500

输入:

  • (12,14,400)
  • (10,20,100)

结果:

  • (12,14)500

输入:

  • (10,20,400)
  • (21,25,100)

结果:

  • (10,20)400

我一直在尝试使用番石榴RangeMap来解决它,但没有明显的方法来交叉重叠的区间。

例如:

private static void method(Record[] array){
    RangeMap<Integer, Integer> rangeMap = TreeRangeMap.create();
    for (Record record : array) {
        rangeMap.put(Range.closed(record.startTime, record.endTime), record.profitRate);
    }
    System.out.println(rangeMap);   
}

public static void main(String[] args) {
    Record[] array = {new Record(1,13,400), new Record(10,20,100)};
    method(array);
}

地图如下: [[1..10] = 400,[10..20] = 100]

有没有办法覆盖重叠行为或可用于解决问题的任何其他数据结构?

2 个答案:

答案 0 :(得分:2)

使用RangeMap.subRangeMap(Range)识别与其他特定范围相交的所有现有范围条目,这样您就可以过滤掉交叉点。

这可能类似于:

void add(RangeMap<Integer, Integer> existing, Range<Integer> range, int add) {
  List<Map.Entry<Range<Integer>, Integer>> overlaps = new ArrayList<>(
      existing.subRangeMap(range).asMapOfRanges().entrySet());
  existing.put(range, add);
  for (Map.Entry<Range, Integer> overlap : overlaps) {
    existing.put(overlap.getKey(), overlap.getValue() + add);
  }
}

答案 1 :(得分:2)

如评论中所述:RangeMap可能不适用于此,因为RangeMap的范围必须是不相交的。

评论中提到了一种解决这个问题的方法:一个可以组合所有范围,生成所有不相交的范围。例如,给定这些范围

  |------------|       :400
           |----------|:100

可以计算其交叉点隐含的所有子范围

  |--------|           :400
           |---|       :500
               |------|:100

在这种情况下,中间范围显然是解决方案。

但总的来说,问题陈述中存在一些不足之处。例如,不完全清楚多个范围是否具有相同的开始时间和/或相同的结束时间。可能与可能的优化相关的事情可能是记录是否被排序&#34;无论如何。

但无论如何,一种通用的方法可能如下:

  • 计算从所有开始时间到从那里开始的记录的映射
  • 计算从所有结束时间到结尾的记录
  • 的映射
  • 遍历所有开始和结束时间(按顺序),跟踪当前时间间隔的累计利润率

(是的,这个 基本上是来自评论的不相交集的生成。但它并没有&#34;构建&#34;一个包含这些信息的数据结构。它只是使用这些信息在飞行中计算最大值。

实现可能如下所示:

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.TreeSet;

class Record
{
    int startTime;
    int endTime;
    int profitRate;

    public Record(int startTime, int endTime, int profitRate)
    {
        this.startTime = startTime;
        this.endTime = endTime;
        this.profitRate = profitRate;
    }

    @Override
    public String toString()
    {
        return "(" + startTime + "..." + endTime + ", " + profitRate + ")";
    }

}

public class MaxRangeFinder
{
    public static void main(String[] args)
    {
        test01();
        test02();
    }

    private static void test01()
    {
        System.out.println("Test case 01:");
        Record[] records = {
            new Record(1,13,400), 
            new Record(10,20,100),
        };
        Record max = computeMax(Arrays.asList(records));
        printNicely(Arrays.asList(records), max);
    }

    private static void test02()
    {
        System.out.println("Test case 02:");
        Record[] records = {
            new Record(1,5,100), 
            new Record(2,6,200),
            new Record(3,4,50),
            new Record(3,4,25),
            new Record(5,8,200),
        };
        Record max = computeMax(Arrays.asList(records));
        printNicely(Arrays.asList(records), max);
    }


    private static Record computeMax(Collection<? extends Record> records)
    {
        // Create mappings from the start times to all records that start 
        // there, and from the end times to the records that end there 
        Map<Integer, List<Record>> recordsByStartTime =
            new LinkedHashMap<Integer, List<Record>>();
        for (Record record : records)
        {
            recordsByStartTime.computeIfAbsent(record.startTime,
                t -> new ArrayList<Record>()).add(record);
        }
        Map<Integer, List<Record>> recordsByEndTime =
            new LinkedHashMap<Integer, List<Record>>();
        for (Record record : records)
        {
            recordsByEndTime.computeIfAbsent(record.endTime,
                t -> new ArrayList<Record>()).add(record);
        }

        // Collect all times where a record starts or ends
        Set<Integer> eventTimes = new TreeSet<Integer>();
        eventTimes.addAll(recordsByStartTime.keySet());
        eventTimes.addAll(recordsByEndTime.keySet());

        // Walk over all events, keeping track of the 
        // starting and ending records
        int accumulatedProfitRate = 0;
        int maxAccumulatedProfitRate = -Integer.MAX_VALUE;
        int maxAccumulatedProfitStartTime = 0;
        int maxAccumulatedProfitEndTime = 0;

        for (Integer eventTime : eventTimes)
        {
            int previousAccumulatedProfitRate = accumulatedProfitRate;

            // Add the profit rate of the starting records
            List<Record> startingRecords = Optional
                .ofNullable(recordsByStartTime.get(eventTime))
                .orElse(Collections.emptyList());
            for (Record startingRecord : startingRecords)
            {
                accumulatedProfitRate += startingRecord.profitRate;
            }

            // Subtract the profit rate of the ending records
            List<Record> endingRecords = Optional
                .ofNullable(recordsByEndTime.get(eventTime))
                .orElse(Collections.emptyList());
            for (Record endingRecord : endingRecords)
            {
                accumulatedProfitRate -= endingRecord.profitRate;
            }

            // Update the information about the maximum, if necessary
            if (accumulatedProfitRate > maxAccumulatedProfitRate)
            {
                maxAccumulatedProfitRate = accumulatedProfitRate;
                maxAccumulatedProfitStartTime = eventTime;
                maxAccumulatedProfitEndTime = eventTime;
            }
            if (previousAccumulatedProfitRate == maxAccumulatedProfitRate &&
                accumulatedProfitRate < previousAccumulatedProfitRate)
            {
                maxAccumulatedProfitEndTime = eventTime;
            }
        }

        return new Record(
            maxAccumulatedProfitStartTime, 
            maxAccumulatedProfitEndTime, 
            maxAccumulatedProfitRate);

    }



    private static void printNicely(
        Collection<? extends Record> records, 
        Record max)
    {
        StringBuilder sb = new StringBuilder();
        int maxEndTime =  Collections.max(records, 
            (r0, r1) -> Integer.compare(r0.endTime, r1.endTime)).endTime;
        for (Record record : records)
        {
            sb.append("     ")
                .append(createString(record, maxEndTime))
                .append("\n");
        }
        sb.append("Max: ").append(createString(max, maxEndTime));
        System.out.println(sb.toString());
    }

    private static String createString(Record record, int maxEndTime)
    {
        StringBuilder sb = new StringBuilder();
        int i = 0;
        while (i < record.startTime)
        {
            sb.append(" ");
            i++;
        }
        sb.append("|");
        while (i < record.endTime)
        {
            sb.append("-");
            i++;
        }
        sb.append("|");
        while (i < maxEndTime)
        {
            sb.append(" ");
            i++;
        }
        sb.append(":").append(record.profitRate);
        return sb.toString();
    }
}

代码中给出的两个测试用例的输出是

Test case 01:
      |------------|       :400
               |----------|:100
Max:           |---|       :500

Test case 02:
      |----|   :100
       |----|  :200
        |-|    :50
        |-|    :25
          |---|:200
Max:      |-|  :400