最近,我遇到了一个编码问题,我想为其找到算法。
Array is 1-based array.
Query types:
1 L R X K : Multiply from L to R with K if it is less than X
2 L R: Print the sum from L to R
INPUT:
6 5
5 4 3 2 6 1
Q1: 1 2 3 5 4
Q2: 2 1 2
Q3: 1 1 3 4 5
Q4: 2 3 4
Q5: 2 4 5
OUTPUT:
21
14
8
Explantion:
Q1: 5 16 12 2 6 1
Q2: Print 21(5+16)
Q3: 5 16 12 2 6 1
Q4: Print 14 (12+2)
Q5: Print 8 (2+6)
通常,这里每个查询都需要O(n)复杂度来解决。但是该算法应该可以在更短的时间内解决它,该怎么办?
答案 0 :(得分:0)
我现在唯一想到的改进(不了解输入数据属性的任何事情)就是使用累加和。
因此有一个数组s[n]
,用于存储输入数据a[n]
的总和...
s[i] = a[0] + a[1] + ... + a[i]
因此,l<=r
中范围O(1)
的计算是这样的:
sum(l,r) = s[r] - s[l-1]
这将使您的第2类查询从O(n)
加速到O(1)
。但是,由于您需要维护sum数组,因此类型1查询会变慢一点。但是复杂度不会改变,因此它们仍为O(n)
,最终复杂度将为O(m.n)
以下是C ++中朴素和改进版本的示例:
const int eod=-1; // end of data
int a[]={ 5,4,3,2,6,1,eod }; // data
int q[]= // queries
{
1,2,3,5,4,
2,1,2,
1,1,3,4,5,
2,3,4,
2,4,5,
eod
};
const int n=(sizeof(a)/sizeof(a[0]))-1;
int i,j,y;
int l,r,x,k;
mm_log->Lines->Add("[naive]");
for (i=0;;)
{
if (q[i]==1) // O(n)
{ i++;
l=q[i]-1; i++;
r=q[i]-1; i++;
x=q[i]; i++;
k=q[i]; i++;
if (l>r){ j=l; l=r; r=j; }
for (y=0,j=l;j<=r;j++)
if (a[j]<x) a[j]*=k;
}
else if (q[i]==2) // O(n)
{ i++;
l=q[i]-1; i++;
r=q[i]-1; i++;
if (l>r){ j=l; l=r; r=j; }
for (y=0,j=l;j<=r;j++) y+=a[j];
mm_log->Lines->Add(y);
}
else break;
}
mm_log->Lines->Add("[cumulative sum]");
int s[n];
for (y=0,j=0;j<n;j++){ y+=a[j]; s[j]=y; } // init cumulative sum
for (i=0;;)
{
if (q[i]==1) // O(n)
{ i++;
l=q[i]-1; i++;
r=q[i]-1; i++;
x=q[i]; i++;
k=q[i]; i++;
if (l>r){ j=l; l=r; r=j; }
for (y=0,j=l;j<=r;j++)
if (a[j]<x) a[j]*=k;
// update cumulative sum
j=0; y=0; if (l) y=s[l-1];
for (j=l;j<n;j++){ y+=a[j]; s[j]=y; }
}
else if (q[i]==2) // O(1)
{ i++;
l=q[i]-1; i++;
r=q[i]-1; i++;
if (l>r){ j=l; l=r; r=j; }
// use cumulative sum
y=s[r]; if (l) y-=s[l-1];
mm_log->Lines->Add(y);
}
else break;
}
只需将mm_log->Lines->Add()
更改为您可以使用的任何打印功能。
此处输出:
[naive]
21
14
8
[cumulative sum]
21
14
8
另一项改进是检查类型为max(l,r)
的类型2查询的O(m)
,因此您可以忽略/剪切a[n]
和s[n]
不需要的元素,从而加快类型1查询的更新过程不超过最大索引。此处更新了此代码:
mm_log->Lines->Add("[cumulative sum]");
int s[n],sn;
for (i=0,sn=-1;;) // check for max index in output ranges O(m)
{
if (q[i]==1) i+=5; // O(1)
else if (q[i]==2) // O(1)
{ i++;
l=q[i]-1; i++;
r=q[i]-1; i++;
if (l>r){ j=l; l=r; r=j; }
if (sn<r) sn=r;
}
else break;
} sn++;
for (y=0,j=0;j<sn;j++){ y+=a[j]; s[j]=y; } // init cumulative sum O(sn)
for (i=0;;) // O(m.sn)
{
if (q[i]==1) // O(sn)
{ i++;
l=q[i]-1; i++;
r=q[i]-1; i++;
x=q[i]; i++;
k=q[i]; i++;
if (l>r){ j=l; l=r; r=j; }
if (r>=sn) r=sn-1;
for (y=0,j=l;j<=r;j++)
if (a[j]<x) a[j]*=k;
// update cumulative sum
j=0; y=0; if (l) y=s[l-1];
for (j=l;j<sn;j++){ y+=a[j]; s[j]=y; }
}
else if (q[i]==2) // O(1)
{ i++;
l=q[i]-1; i++;
r=q[i]-1; i++;
if (l>r){ j=l; l=r; r=j; }
// use cumulative sum
y=s[r]; if (l) y-=s[l-1];
mm_log->Lines->Add(y);
}
else break;
}
导致复杂性O(m.sn)
,其中sn<=n
。如果您为每个类型2查询存储sn
,那么您可以将其作为变量处理,以加快处理速度。
在复杂情况下,n
是输入数组a[]
中的元素数,而m
是查询数。