Maximum sum of a subsequence of length L with a restriction

Time: 2019-03-26 13:26:12

Tags: python algorithm dynamic-programming

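Given an array of positive integers, how do I find the subsequence of length L with the maximum sum, such that the distance between any two adjacent elements of the subsequence does not exceed K?

I have the following solution, but I don't know how to take the length L into account.

1 <= N <= 100000, 1 <= L <= 200, 1 <= K <= N

f[i] holds the maximum sum of a subsequence that ends at i.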

for i in range(K, N):
    f[i] = INT_MIN
    for j in range(1, K+1):
        f[i] = max(f[i], f[i-j] + a[i])
return max(f)

4 answers:

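Answer 0: (score: 5)

(Edit: slightly simplified non-recursive solution)

You can do it like this, considering at each step only whether the item should be included or excluded.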

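def f(maxK,K, N, L, S):
    # maxK: distance limit; K: positions that may still be skipped
    # N: remaining items; L: items still to pick; S: sum picked so far
    if L == 0:
        return S
    if not N or K == 0:
        # ran out of items, or the gap got too large, before L items were picked
        return float('-inf')
    #either element is included (the gap budget resets)
    included = f(maxK,maxK, N[1:], L-1, S + N[0]  )
    #or excluded; the gap only shrinks once something has been picked
    #(S > 0 means "something picked", since all values are positive)
    excluded = f(maxK,K-1 if S else K, N[1:], L, S )
    return max(included, excluded)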


assert f(2,2,[10,1,1,1,1,10],3,0) == 12
assert f(3,3,[8, 3, 7, 6, 2, 1, 9, 2, 5, 4],4,0) == 30

If N is very long, you can consider switching to a table version, or you could change the input to tuples and use memoization.
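As a rough illustration of the memoization idea, here is a minimal sketch. The name `max_sum_memo`, the index-based state and the `None` marker for "nothing picked yet" are assumptions made for this example, not part of the answer's code. It caches on (position, remaining gap, remaining picks) instead of slicing the list, and returns `-inf` when no subsequence of length L exists.

```python
from functools import lru_cache

def max_sum_memo(values, K, L):
    """Include/exclude recursion, memoized on (position, gap left, picks left)."""
    values = tuple(values)   # immutable copy used by the cached helper
    n = len(values)
    NEG = float("-inf")      # marks "no valid subsequence from this state"

    @lru_cache(maxsize=None)
    def best(i, gap, left):
        # i: next position to decide on; left: elements still to pick
        # gap: positions that may still be skipped; None before the first pick
        if left == 0:
            return 0
        if i == n or gap == 0:
            return NEG
        included = values[i] + best(i + 1, K, left - 1)
        excluded = best(i + 1, gap if gap is None else gap - 1, left)
        return max(included, excluded)

    return best(0, None, L)

print(max_sum_memo([10, 1, 1, 1, 1, 10], 2, 3))
```

The recursion depth still grows with the length of the input, so this is only suitable for short arrays.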

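Since the OP later added the information that N can be 100 000, we can't really use a recursive solution like this one. So here is a solution that runs in O(n K L) time and keeps an O(n L) table: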

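import numpy as np

def f(n,K,L):
    # t[i,l]: best sum of l elements ending at position i (fewer when i < l-1);
    # column 0 stays 0
    t = np.zeros((len(n),L+1))

    for l in range(1,L+1):
        for i in range(len(n)):
            # extend the best (l-1)-element subsequence ending at most K positions back
            t[i,l] = n[i] + max( (t[i-k,l-1] for k in range(1,K+1) if i-k >= 0), default = 0 )

    # relies on positive values and len(n) >= L, so the overall maximum is an L-element sum
    return np.max(t)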


assert f([10,1,1,1,1,10],2,3) == 12
assert f([8, 3, 7, 6, 2, 1, 9],3,4) == 30

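Explanation of the non-recursive solution. Each cell t[i, l] of the table holds the value of the maximum subsequence with exactly l elements that uses the element at position i and otherwise only elements at positions below i, with adjacent elements at most K apart.

Subsequences of length 1 (those in t[i, 1]) must consist of a single element, n[i].

Longer subsequences consist of n[i] plus a subsequence of l-1 elements that ends at most K rows earlier; we pick the one with the maximum value. By iterating in this order, we ensure that value has already been computed.

Memory can be reduced further by taking into account that you only look back at most K steps.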
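One way the look-back remark could be acted on is sketched below, assuming L >= 1. The function name `max_sum_window` and the choice of iterating positions in the outer loop are assumptions for this illustration. Only the rows of the last K positions are kept (a `deque` with `maxlen=K`), so memory drops to O(K L) while the time stays O(n K L). Unreachable states are marked with `-inf` rather than 0.

```python
from collections import deque

def max_sum_window(values, K, L):
    """Table DP that keeps only the rows of the last K positions."""
    NEG = float("-inf")
    window = deque(maxlen=K)   # window[j][l]: best sum of l elements ending at a recent position
    best = NEG
    for v in values:
        row = [NEG] * (L + 1)
        row[1] = v             # a single-element subsequence
        for l in range(2, L + 1):
            # best (l-1)-element subsequence ending within the last K positions
            prev = max((r[l - 1] for r in window), default=NEG)
            if prev != NEG:
                row[l] = v + prev
        best = max(best, row[L])
        window.append(row)     # the oldest row is dropped automatically
    return best

print(max_sum_window([10, 1, 1, 1, 1, 10], 2, 3))
```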

Answer 1: (score: 3)

Here is a bottom-up (i.e. no recursion) dynamic programming solution in Python. It takes memory O(l * n) and time O(l * n * k).

def max_subseq_sum(k, l, values):
    # table[i][j] will be the highest value from a sequence of length j
    # ending at position i
    table = []
    for i in range(len(values)):
        # Length 0 sums to 0; length 1 is just values[i].
        table.append([0, values[i]])
        # By length of previous subsequence
        for subseq_len in range(1, l):
            # We look back up to k for the best.
            prev_val = None
            for last_i in range(i-k, i):
                # We don't look back if the sequence was not that long.
                if subseq_len <= last_i+1:
                    # Is this better?
                    this_val = table[last_i][subseq_len]
                    if prev_val is None or prev_val < this_val:
                        prev_val = this_val
            # Do we have a best to offer?
            if prev_val is not None:
                table[i].append(prev_val + values[i])

    # Now we look for the best entry of length l.
    best_val = None
    for row in table:
        # A row with entries for lengths 0..l has len > l.
        if l < len(row):
            if best_val is None or best_val < row[l]:
                best_val = row[l]
    return best_val

print(max_subseq_sum(2, 3, [10, 1, 1, 1, 1, 10]))
print(max_subseq_sum(3, 4, [8, 3, 7, 6, 2, 1, 9, 2, 5, 4]))

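If I wanted to be slightly clever, I could make the memory O(n) pretty easily by calculating one layer at a time and throwing away the previous one. It takes a lot of cleverness to reduce the running time to O(l*n*log(k)), but that is doable. (Use a priority queue to get the best value among the last k. It is O(log(k)) to update it for each element, but it naturally grows. Every k values you throw it away and rebuild it for an O(k) cost, incurred O(n/k) times, for a total rebuild cost of O(n).)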

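The clever version needs memory O(n). Its time is O(n*l*log(k)) in the worst case and O(n*l) on average. The worst case is hit when the input is sorted in ascending order.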
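The layer-at-a-time idea can be sketched as follows, assuming l >= 1. Note that this is not the priority-queue approach described above: it swaps in a monotonic deque (the usual sliding-window-maximum technique), which gives O(n*l) time and O(n) memory regardless of input order. The name `max_sum_deque` is hypothetical.

```python
from collections import deque

def max_sum_deque(values, k, l):
    """Layer-at-a-time DP; a monotonic deque yields each window maximum in O(1) amortized."""
    n = len(values)
    NEG = float("-inf")
    layer = list(values)          # best sums of 1-element subsequences ending at each i
    for _ in range(2, l + 1):
        nxt = [NEG] * n
        window = deque()          # indices into layer, with decreasing layer values
        for i in range(n):
            # drop indices that fell out of the window [i-k, i-1]
            while window and window[0] < i - k:
                window.popleft()
            # the front of the deque is the best predecessor, if any is reachable
            if window and layer[window[0]] != NEG:
                nxt[i] = values[i] + layer[window[0]]
            # make position i available to later positions
            while window and layer[window[-1]] <= layer[i]:
                window.pop()
            window.append(i)
        layer = nxt               # the previous layer is thrown away
    return max(layer, default=NEG)

print(max_sum_deque([10, 1, 1, 1, 1, 10], 2, 3))
print(max_sum_deque([8, 3, 7, 6, 2, 1, 9, 2, 5, 4], 3, 4))
```

Each index enters and leaves the deque at most once per layer, which is where the O(n) cost per layer comes from.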

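Answer 2: (score: 2)

Extending the code for itertools.combinations shown in the docs, I built a version that takes an argument for the maximum index distance (K) between two values. It only needed an additional and indices[i] - indices[i-1] < K check in the iteration: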

def combinations_with_max_dist(iterable, r, K):
    # combinations('ABCD', 2) --> AB AC AD BC BD CD
    # combinations(range(4), 3) --> 012 013 023 123
    pool = tuple(iterable)
    n = len(pool)
    if r > n:
        return
    indices = list(range(r))
    yield tuple(pool[i] for i in indices)
    while True:
        for i in reversed(range(r)):
            if indices[i] != i + n - r and indices[i] - indices[i-1] < K:
                break
        else:
            return               
        indices[i] += 1        
        for j in range(i+1, r):
            indices[j] = indices[j-1] + 1
        yield tuple(pool[i] for i in indices)

Using this, you can brute-force over all combinations that respect K, and then find the one with the maximum sum of values:

def find_subseq(a, L, K):
    return max((sum(values), values) for values in combinations_with_max_dist(a, L, K))

Result:

print(*find_subseq([10, 1, 1, 1, 1, 10], L=3, K=2))
# 12 (10, 1, 1)
print(*find_subseq([8, 3, 7, 6, 2, 1, 9, 2, 5, 4], L=4, K=3))
# 30 (8, 7, 6, 9)

I'm not sure about the performance for longer lists of values, though...
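To get a feel for how the brute force scales, the number of index combinations it has to enumerate can be counted directly. The sketch below is an illustration only, and the helper name `count_valid_combinations` is hypothetical. It counts the index tuples of length L whose adjacent indices are at most K apart, assuming L >= 1.

```python
def count_valid_combinations(n, L, K):
    """Number of increasing index tuples of length L over range(n) with adjacent gaps <= K."""
    counts = [1] * n   # counts[i]: valid tuples of length 1 ending at index i
    for _ in range(L - 1):
        # a longer tuple ending at i extends one ending at most K positions earlier
        counts = [sum(counts[max(0, i - K):i]) for i in range(n)]
    return sum(counts)

print(count_valid_combinations(6, 3, 2))   # 12 combinations for the first example
```

The count is bounded above by n * K**(L-1) and grows quickly with L and K, so the brute force is realistic only for small inputs.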

Answer 3: (score: 1)


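Algorithm

Basic idea:

  • Iterate on the input array, choosing each index as the first taken element.
  • Then recurse on each first taken element, marking its index as firstIdx.
    • The next possible index is in the range [firstIdx + 1, firstIdx + K], both ends inclusive.
    • Loop over the range, calling recursively on each index with L - 1 as the new L.
  • Optionally, for each pair (firstIdx, L), cache its max sum for reuse. This is probably necessary for large input.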

Constraints

  • array length <= 1 << 17 // 131072
  • K <= 1 << 6 // 64
  • L <= 1 << 8 // 256

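Complexity:

  • Time: O(n * L * K)
    Since each (firstIdx, L) pair is calculated only once, and that calculation contains an iteration over K.
  • Space: O(n * L)
    For the cache, plus the method stack of the recursive calls.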

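Tips:

  • The recursion depth is related to L, not to the array length.
  • The defined constraints are not the actual limits; they could be larger, though I didn't test how large.
    Basically:
    • Both the array length and K can actually be of any size as long as there is enough memory, since they are handled via iteration.
    • L is handled via recursion, so it does have a limit.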

Code - in Java

SubSumLimitedDistance.java:

import java.util.HashMap;
import java.util.Map;

public class SubSumLimitedDistance {
    public static final long NOT_ENOUGH_ELE = -1; // sum that indicate not enough element, should be < 0,
    public static final int MAX_ARR_LEN = 1 << 17; // max length of input array,
    public static final int MAX_K = 1 << 6; // max K, should not be too long, otherwise slow,
    public static final int MAX_L = 1 << 8; // max L, should not be too long, otherwise stackoverflow,

    /**
     * Find max sum of sum array.
     *
     * @param arr
     * @param K
     * @param L
     * @return max sum,
     */
    public static long find(int[] arr, int K, int L) {
        if (K < 1 || K > MAX_K)
            throw new IllegalArgumentException("K should be between [1, " + MAX_K + "], but get: " + K);
        if (L < 0 || L > MAX_L)
            throw new IllegalArgumentException("L should be between [0, " + MAX_L + "], but get: " + L);
        if (arr.length > MAX_ARR_LEN)
            throw new IllegalArgumentException("input array length should <= " + MAX_ARR_LEN + ", but get: " + arr.length);

        Map<Integer, Map<Integer, Long>> cache = new HashMap<>(); // cache,

        long maxSum = NOT_ENOUGH_ELE;
        for (int i = 0; i < arr.length; i++) {
            long sum = findTakeFirst(arr, K, L, i, cache);
            if (sum == NOT_ENOUGH_ELE) break; // not enough elements,

            if (sum > maxSum) maxSum = sum; // larger found,
        }

        return maxSum;
    }

    /**
     * Find max sum of sum array, with index of first taken element specified,
     *
     * @param arr
     * @param K
     * @param L
     * @param firstIdx index of first taken element,
     * @param cache
     * @return max sum,
     */
    private static long findTakeFirst(int[] arr, int K, int L, int firstIdx, Map<Integer, Map<Integer, Long>> cache) {
        // System.out.printf("findTakeFirst(): K = %d, L = %d, firstIdx = %d\n", K, L, firstIdx);
        if (L == 0) return 0; // done,
        if (firstIdx + L > arr.length) return NOT_ENOUGH_ELE; // not enough elements,

        // check cache,
        Map<Integer, Long> map = cache.get(firstIdx);
        Long cachedResult;
        if (map != null && (cachedResult = map.get(L)) != null) {
            // System.out.printf("hit cache, cached result = %d\n", cachedResult);
            return cachedResult;
        }

        // cache not exists, calculate,
        long maxRemainSum = NOT_ENOUGH_ELE;
        for (int i = firstIdx + 1; i <= firstIdx + K; i++) {
            long remainSum = findTakeFirst(arr, K, L - 1, i, cache);
            if (remainSum == NOT_ENOUGH_ELE) break; // not enough elements,
            if (remainSum > maxRemainSum) maxRemainSum = remainSum;
        }

        if ((map = cache.get(firstIdx)) == null) cache.put(firstIdx, map = new HashMap<>());

        if (maxRemainSum == NOT_ENOUGH_ELE) { // not enough elements,
            map.put(L, NOT_ENOUGH_ELE); // cache - as not enough elements,
            return NOT_ENOUGH_ELE;
        }

        long maxSum = arr[firstIdx] + maxRemainSum; // max sum,
        map.put(L, maxSum); // cache - max sum,

        return maxSum;
    }
}

SubSumLimitedDistanceTest.java:
(Test case, via TestNG)

import org.testng.Assert;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

import java.util.concurrent.ThreadLocalRandom;

public class SubSumLimitedDistanceTest {
    private int[] arr;
    private int K;
    private int L;
    private int maxSum;

    private int[] arr2;
    private int K2;
    private int L2;
    private int maxSum2;

    private int[] arrMax;
    private int KMax;
    private int KMaxLargest;

    private int LMax;
    private int LMaxLargest;

    @BeforeClass
    private void setUp() {
        // init - arr,
        arr = new int[]{10, 1, 1, 1, 1, 10};
        K = 2;
        L = 3;
        maxSum = 12;

        // init - arr2,
        arr2 = new int[]{8, 3, 7, 6, 2, 1, 9, 2, 5, 4};
        K2 = 3;
        L2 = 4;
        maxSum2 = 30;

        // init - arrMax,
        arrMax = new int[SubSumLimitedDistance.MAX_ARR_LEN];
        ThreadLocalRandom rd = ThreadLocalRandom.current();
        long maxLongEle = Long.MAX_VALUE / SubSumLimitedDistance.MAX_ARR_LEN;
        int maxEle = maxLongEle > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) maxLongEle;
        for (int i = 0; i < arrMax.length; i++) {
            arrMax[i] = rd.nextInt(maxEle);
        }

        KMax = 5;
        LMax = 10;

        KMaxLargest = SubSumLimitedDistance.MAX_K;
        LMaxLargest = SubSumLimitedDistance.MAX_L;
    }

    @Test
    public void test() {
        Assert.assertEquals(SubSumLimitedDistance.find(arr, K, L), maxSum);
        Assert.assertEquals(SubSumLimitedDistance.find(arr2, K2, L2), maxSum2);
    }

    @Test(timeOut = 6000)
    public void test_veryLargeArray() {
        run_printDuring(arrMax, KMax, LMax);
    }

    @Test(timeOut = 60000) // takes seconds,
    public void test_veryLargeArrayL() {
        run_printDuring(arrMax, KMax, LMaxLargest);
    }

    @Test(timeOut = 60000) // takes seconds,
    public void test_veryLargeArrayK() {
        run_printDuring(arrMax, KMaxLargest, LMax);
    }

    // run find once, and print during,
    private void run_printDuring(int[] arr, int K, int L) {
        long startTime = System.currentTimeMillis();
        long sum = SubSumLimitedDistance.find(arr, K, L);
        long during = System.currentTimeMillis() - startTime; // during in milliseconds,
        System.out.printf("arr length = %5d, K = %3d, L = %4d, max sum = %15d, running time = %.3f seconds\n", arr.length, K, L, sum, during / 1000.0);
    }

    @Test
    public void test_corner_notEnoughEle() {
        Assert.assertEquals(SubSumLimitedDistance.find(new int[]{1}, 2, 3), SubSumLimitedDistance.NOT_ENOUGH_ELE); // not enough element,
        Assert.assertEquals(SubSumLimitedDistance.find(new int[]{0}, 1, 3), SubSumLimitedDistance.NOT_ENOUGH_ELE); // not enough element,
    }

    @Test
    public void test_corner_ZeroL() {
        Assert.assertEquals(SubSumLimitedDistance.find(new int[]{1, 2, 3}, 2, 0), 0); // L = 0,
        Assert.assertEquals(SubSumLimitedDistance.find(new int[]{0}, 1, 0), 0); // L = 0,
    }

    @Test(expectedExceptions = IllegalArgumentException.class)
    public void test_invalid_K() {
        // SubSumLimitedDistance.find(new int[]{1, 2, 3}, 0, 2); // K = 0,
        // SubSumLimitedDistance.find(new int[]{1, 2, 3}, -1, 2); // K = -1,
        SubSumLimitedDistance.find(new int[]{1, 2, 3}, SubSumLimitedDistance.MAX_K + 1, 2); // K = SubSumLimitedDistance.MAX_K+1,
    }

    @Test(expectedExceptions = IllegalArgumentException.class)
    public void test_invalid_L() {
        // SubSumLimitedDistance.find(new int[]{1, 2, 3}, 2, -1); // L = -1,
        SubSumLimitedDistance.find(new int[]{1, 2, 3}, 2, SubSumLimitedDistance.MAX_L + 1); // L = SubSumLimitedDistance.MAX_L+1,
    }

    @Test(expectedExceptions = IllegalArgumentException.class)
    public void test_invalid_tooLong() {
        SubSumLimitedDistance.find(new int[SubSumLimitedDistance.MAX_ARR_LEN + 1], 2, 3); // input array too long,
    }
}

Output of the test cases for large input:

arr length = 131072, K =   5, L =   10, max sum =     20779205738, running time = 0.303 seconds
arr length = 131072, K =  64, L =   10, max sum =     21393422854, running time = 1.917 seconds
arr length = 131072, K =   5, L =  256, max sum =    461698553839, running time = 9.474 seconds