所以我在查看Rosetta Code上的合并排序的C示例,我对merge()函数的工作原理有点困惑。我认为这是他们使用的语法,用冒号和我的方式抛弃了我。
void merge (int *a, int n, int m) {
int i, j, k;
int *x = malloc(n * sizeof (int));
for (i = 0, j = m, k = 0; k < n; k++) {
x[k] = j == n ? a[i++]
: i == m ? a[j++]
: a[j] < a[i] ? a[j++]
: a[i++];
}
for (i = 0; i < n; i++) {
a[i] = x[i];
}
free(x);
}
void merge_sort (int *a, int n) {
if (n < 2)
return;
int m = n / 2;
merge_sort(a, m);
merge_sort(a + m, n - m);
merge(a, n, m);
}
merge()函数的for循环到底发生了什么?有人可以解释一下吗?
答案 0 :(得分:0)
阅读评论:
void merge (int *a, int n, int m) {
int i, j, k;
// inefficient: allocating a temporary array with malloc
// once per merge phase!
int *x = malloc(n * sizeof (int));
// merging left and right halfs of a into temporary array x
for (i = 0, j = m, k = 0; k < n; k++) {
x[k] = j == n ? a[i++] // right half exhausted, take from left
: i == m ? a[j++] // left half exhausted, take from right
: a[j] < a[i] ? a[j++] // right element smaller, take that
: a[i++]; // otherwise take left element
}
// copy temporary array back to original array.
for (i = 0; i < n; i++) {
a[i] = x[i];
}
free(x); // free temporary array
}
void merge_sort (int *a, int n) {
if (n < 2)
return;
int m = n / 2;
// inefficient: should not recurse if n == 2
// recurse to sort left half
merge_sort(a, m);
// recurse to sort right half
merge_sort(a + m, n - m);
// merge left half and right half in place (via temp array)
merge(a, n, m);
}
merge
函数的更简单,更高效的版本,只使用了一半的临时空间:
static void merge(int *a, int n, int m) {
int i, j, k;
int *x = malloc(m * sizeof (int));
// copy left half to temporary array
for (i = 0; i < m; i++) {
x[i] = a[i];
}
// merge left and right half
for (i = 0, j = m, k = 0; i < m && j < n; k++) {
a[k] = a[j] < x[i] ? a[j++] : x[i++];
}
// finish copying left half
while (i < m) {
a[k++] = x[i++];
}
}
更快版本的merge_sort
涉及分配大小为x
的临时数组n * sizeof(*a)
并将其传递给递归函数merge_sort1
,该函数使用as调用merge
额外参数也是如此。 merge
中的逻辑也在i
和j
进行了一半的比较:
static void merge(int *a, int n, int m, int *x) {
int i, j, k;
for (i = 0; i < m; i++) {
x[i] = a[i];
}
for (i = 0, j = m, k = 0;;) {
if (a[j] < x[i]) {
a[k++] = a[j++];
if (j >= n) break;
} else {
a[k++] = x[i++];
if (i >= m) return;
}
}
while (i < m) {
a[k++] = x[i++];
}
}
static void merge_sort1(int *a, int n, int *x) {
if (n >= 2) {
int m = n / 2;
if (n > 2) {
merge_sort1(a, m, x);
merge_sort1(a + m, n - m, x);
}
merge(a, n, m, x);
}
}
void merge_sort(int *a, int n) {
if (n < 2)
return;
int *x = malloc(n / 2 * sizeof (int));
merge_sort1(a, n, x);
free(x);
}