是否有更有效的方法来实现这一目标:
给定大小为A
的数组n
以及两个正整数a
和b
,找到所有对floor(abs(A[i]-A[j])*a/b)
上的总和(i, j)
,其中{ {1}}。
0 <= i < j < n
为了优化这一点,我对数组进行了排序(int A[n];
int a, b; // assigned some positive integer values
...
int total = 0;
for (int i = 0; i < n; i++) {
for (int j = i+1; j < n; j++) {
total += abs(A[i]-A[j])*a/b; // want integer division here
}
}
),然后没有使用O(nlogn)
函数。此外,我在内部for循环之前缓存了值abs
,因此我可以按顺序从a[i]
读取内容。我正在考虑预先计算A
并将其存储在一个浮点数中,但额外的转换只会使它变慢(特别是因为我想取决于结果)。
我无法想出一个比a/b
更好的解决方案。
答案 0 :(得分:2)
是的,有一种更有效的算法。它可以在O(n * log n)中完成。我不希望有一个渐近更快的方式,但我对证据的想法很远。
首先在O(n * log n)时间内对数组进行排序。
现在,让我们看一下术语
floor((A[j]-A[i])*a/b) = floor ((A[j]*a - A[i]*a)/b)
代表0 <= i < j < n
。对于每个0 <= k < n
,请使用A[k]*a = q[k]*b + r[k]
撰写0 <= r[k] < b
。
对于A[k] >= 0
,我们q[k] = (A[k]*a)/b
和r[k] = (A[k]*a)%b
具有整数除法,对于A[k] < 0
,我们有q[k] = (A[k]*a)/b - 1
和r[k] = b + (A[k]*a)%b
,除非{{ 1}}除b
,在这种情况下我们有A[k]*a
和q[k] = (A[k]*a)/b
。
现在我们重写条款:
r[k] = 0
每个floor((A[j]*a - A[i]*a)/b) = floor(q[j] - q[i] + (r[j] - r[i])/b)
= q[j] - q[i] + floor((r[j] - r[i])/b)
显示q[k]
次带有正号码(适用于k
)和i = 0, 1, .. , k-1
次带有负号码(适用于n-1-k
),因此其总贡献总和是
j = k+1, k+2, ..., n-1
剩余部分仍然需要考虑。现在,自(k - (n-1-k))*q[k] = (2*k+1-n)*q[k]
以来,我们已经
0 <= r[k] < b
当-b < r[j] - r[i] < b
时{p>和floor((r[j]-r[i])/b)
为0,r[j] >= r[i]
时-1
为r[j] < r[i]
。所以
n-1
∑ floor((A[j]-A[i])*a/b) = ∑ (2*k+1-n)*q[k] - inversions(r)
i<j k=0
其中反转是(i,j)
和0 <= i < j < n
的一对r[j] < r[i]
索引。
计算q[k]
和r[k]
并将(2*k+1-n)*q[k]
求和在O(n)时间内完成。
仍然有效地计算r[k]
数组的反转。
对于每个索引0 <= k < n
,让c(k)
为i < k
的{{1}},即r[k] < r[i]
出现的倒数的数量,k
更大的指数。
然后显然反转次数为∑ c(k)
。
另一方面,c(k)
是在稳定排序中移动r[k]
后面的元素数量(稳定性在这里很重要)。
计算这些移动,因此在合并排序时很容易进行数组的反转。
因此反转也可以用O(n * log n)计算,给出总体复杂度为O(n * log n)。
一个简单的不科学基准的示例实现(但是天真的二次算法与上面的差异之间的差异是如此之大,以至于不科学的基准是足够的结论)。
#include <stdlib.h>
#include <stdio.h>
#include <time.h>
long long mergesort(int *arr, unsigned elems);
long long merge(int *arr, unsigned elems, int *scratch);
long long nosort(int *arr, unsigned elems, long long a, long long b);
long long withsort(int *arr, unsigned elems, long long a, long long b);
int main(int argc, char *argv[]) {
unsigned count = (argc > 1) ? strtoul(argv[1],NULL,0) : 1000;
srand(time(NULL)+count);
long long a, b;
b = 1000 + 9000.0*rand()/(RAND_MAX+1.0);
a = b/3 + (b-b/3)*1.0*rand()/(RAND_MAX + 1.0);
int *arr1, *arr2;
arr1 = malloc(count*sizeof *arr1);
arr2 = malloc(count*sizeof *arr2);
if (!arr1 || !arr2) {
fprintf(stderr,"Allocation failed\n");
exit(EXIT_FAILURE);
}
unsigned i;
for(i = 0; i < count; ++i) {
arr1[i] = 20000.0*rand()/(RAND_MAX + 1.0) - 2000;
}
for(i = 0; i < count; ++i) {
arr2[i] = arr1[i];
}
long long res1, res2;
double start = clock();
res1 = nosort(arr1,count,a,b);
double stop = clock();
printf("Naive: %lld in %.3fs\n",res1,(stop-start)/CLOCKS_PER_SEC);
start = clock();
res2 = withsort(arr2,count,a,b);
stop = clock();
printf("Sorting: %lld in %.3fs\n",res2,(stop-start)/CLOCKS_PER_SEC);
return EXIT_SUCCESS;
}
long long nosort(int *arr, unsigned elems, long long a, long long b) {
long long total = 0;
unsigned i, j;
long long m;
for(i = 0; i < elems-1; ++i) {
m = arr[i];
for(j = i+1; j < elems; ++j) {
long long d = (arr[j] < m) ? (m-arr[j]) : (arr[j]-m);
total += (d*a)/b;
}
}
return total;
}
long long withsort(int *arr, unsigned elems, long long a, long long b) {
long long total = 0;
unsigned i;
mergesort(arr,elems);
for(i = 0; i < elems; ++i) {
long long q, r;
q = (arr[i]*a)/b;
r = (arr[i]*a)%b;
if (r < 0) {
r += b;
q -= 1;
}
total += (2*i+1LL-elems)*q;
arr[i] = (int)r;
}
total -= mergesort(arr,elems);
return total;
}
long long mergesort(int *arr, unsigned elems) {
if (elems < 2) return 0;
int *scratch = malloc((elems + 1)/2*sizeof *scratch);
if (!scratch) {
fprintf(stderr,"Alloc failure\n");
exit(EXIT_FAILURE);
}
return merge(arr, elems, scratch);
}
long long merge(int *arr, unsigned elems, int *scratch) {
if (elems < 2) return 0;
unsigned left = (elems + 1)/2, right = elems-left, i, j, k;
long long inversions = 0;
inversions += merge(arr, left, scratch);
inversions += merge(arr+left,right,scratch);
if (arr[left] < arr[left-1]) {
for(i = 0; i < left; ++i) {
scratch[i] = arr[i];
}
i = 0; j = 0; k = 0;
int *lptr = scratch, *rptr = arr+left;
while(i < left && j < right) {
if (rptr[j] < lptr[i]) {
arr[k++] = rptr[j++];
inversions += (left-i);
} else {
arr[k++] = lptr[i++];
}
}
while(i < left) arr[k++] = lptr[i++];
}
return inversions;
}