包含全零的子系统的数量

时间:2014-04-15 16:11:49

标签: matrix dynamic-programming

有没有办法找到一些矩形子矩阵,其中包含复杂度小于O(n ^ 3)的全零,其中n是给定矩阵的维数?

1 个答案:

答案 0 :(得分:0)

这是一个解决方案 O(n²logn)

首先,让我们将主要问题转换为这样的事情:

对于给定的histogram,找到包含全零的子矩阵的数量。

如何转换?

对于每个位置,计算从该位置开始并仅包含零的列的高度。

示例:

10010    01101
00111    12000
00001 -> 23110
01101    30020
01110    40001

可以在O(n²)中轻松找到。

for(int i = 1; i <= n; i++)
    for(int j = 1; j <= m; j++)
        up[i][j] = arr[i][j] ? 0 : 1 + up[i - 1][j];

现在我们可以将每一行视为具有给定高度的直方图。

让我们用直方图解决问题。

我们的目标是从左到右行进所有高度,在每一步我们将更新阵列L. 每个高度的此数组将包含最大宽度,以便我们可以从当前位置,左侧和给定高度制作此宽度的矩形。

考虑例子:

0
0   0
0 000
00000   -> heights: 6 3 4 4 5 2
000000
000000

L[6]:   1     0     0     0     0     0
L[5]:   1     0     0     0     1     0
L[4]:   1     0     1     2     3     0
L[3]:   1     2     3     4     5     0
L[2]:   1     2     3     4     5     6
L[1]:   1     2     3     4     5     6
steps:  1     2     3     4     5     6

正如您所看到的,如果我们添加所有这些数字,我们将收到给定直方图的答案。

我们可以简单地更新O(n)中的数组L,但是我们也可以在O(log n)中使用可以在区间中添加的段树(具有延迟传播),在区间中设置值并从中获取总和间隔。

在每一步中,我们只在区间[1,高度]中加1,在区间[高度+ 1,高度高度]中设置0,并从区间[1,maxHeight]得到总和。

height - 柱状图中当前列的高度。

maxHeight - 直方图中列的最大高度。

这就是你如何获得O(n²* log n)解决方案:)

以下是C ++中的主要代码:

const int MAXN = 1000;
int n;
int arr[MAXN + 5][MAXN + 5]; // stores given matrix
int up[MAXN + 5][MAXN + 5]; // heights of columns of zeros
long long answer;

long long calculate(int *h, int maxh) { // solve it for histogram
    clearTree();

    long long result = 0;
    for(int i = 1; i <= n; i++) {
        add(1, h[i]); // add 1 to [1, h[i]]
        set(h[i] + 1, maxh); // set 0 in [h[i] + 1, maxh];
        result += query(); // get sum from [1, maxh]
    }

    return result;
}
int main() {
    ios_base::sync_with_stdio(0);
    cin >> n;

    for(int i = 1; i <= n; i++)
        for(int j = 1; j <= n; j++)
            cin >> arr[i][j]; // read the data

    for(int i = 1; i <= n; i++)
        for(int j = 1; j <= n; j++)
            up[i][j] = arr[i][j] ? 0 : 1 + up[i - 1][j]; // calculate values of up

    for(int i = 1; i <= n; i++)
        answer += calculate(up[i], i); // calculate for each row

    cout << answer << endl;
}

这是代码的开头,分段树:

#include <iostream>
using namespace std;

// interval-interval tree that stores sums

const int p = 11;
int sums[1 << p];
int lazy[1 << p];
int need[1 << p];
const int M = 1 << (p - 1);

void update(int node) {
    if(need[node] == 1) { // add
        sums[node] += lazy[node];
        if(node < M) {
            need[node * 2] = need[node * 2] == 2 ? 2 : 1;
            need[node * 2 + 1] = need[node * 2 + 1] == 2 ? 2 : 1;
            lazy[node * 2] += lazy[node] / 2;
            lazy[node * 2 + 1] += lazy[node] / 2;
        }
    } else if(need[node] == 2) { // set
        sums[node] = lazy[node];
        if(node < M) {
            need[node * 2] = need[node * 2 + 1] = 2;
            lazy[node * 2] = lazy[node] / 2;
            lazy[node * 2 + 1] = lazy[node] / 2;
        }
    }
    need[node] = 0;
    lazy[node] = 0;
}

void insert(int node, int l, int r, int lq, int rq, int value, int id) {
    update(node);
    if(lq <= l && r <= rq) {
        need[node] = id;
        lazy[node] = value * (r - l + 1);
        update(node);
        return;
    }
    int mid = (l + r) / 2;
    if(lq <= mid) insert(node * 2, l, mid, lq, rq, value, id);
    if(mid + 1 <= rq) insert(node * 2 + 1, mid + 1, r, lq, rq, value, id);
    sums[node] = sums[node * 2] + sums[node * 2 + 1];
}


int query() {
    return sums[1]; // we only need to know sum of the whole interval
}

void clearTree() {
    for(int i = 1; i < 1 << p; i++)
        sums[i] = lazy[i] = need[i] = 0;
}

void add(int left, int right) {
    insert(1, 0, M - 1, left, right, 1, 1);
}

void set(int left, int right) {
    insert(1, 0, M - 1, left, right, 0, 2);
}

// end of the tree