使用Numpy数组而不是list时排序算法中的RecursionError

时间:2018-04-22 19:24:24

标签: python arrays sorting numpy recursion

这是我的情况:我创建一个包含100,000个元素的numpy数组,对数组进行洗牌,然后执行以下三个操作之一:

1)使用合并排序对数组进行排序,再次对数组进行排序,然后尝试使用快速排序进行排序,其中我得到“RecursionError:比较时超出了最大递归深度”

2)使用快速排序对数组进行排序,这非常合适。

3)立即将数组转换为列表并执行步骤1,这不会引发任何错误。

为什么我在合并排序后运行快速排序后才会出现递归错误?

为什么在使用列表而不是Numpy数组时不会出现此错误?

非常感谢您的帮助。

以下是完整代码:

import random
import numpy as np

def quick_sort(ARRAY):
    """Pure implementation of quick sort algorithm in Python
    :param collection: some mutable ordered collection with heterogeneous
    comparable items inside
    :return: the same collection ordered by ascending
    Examples:
    >>> quick_sort([0, 5, 3, 2, 2])
    [0, 2, 2, 3, 5]
    >>> quick_sort([])
    []
    >>> quick_sort([-2, -5, -45])
    [-45, -5, -2]
    """
    ARRAY_LENGTH = len(ARRAY)
    if( ARRAY_LENGTH <= 1):
        return ARRAY
    else:
        PIVOT = ARRAY[0]
        GREATER = [ element for element in ARRAY[1:] if element > PIVOT ]
        LESSER = [ element for element in ARRAY[1:] if element <= PIVOT ]
        return quick_sort(LESSER) + [PIVOT] + quick_sort(GREATER)

def merge_sort(collection):
    """Pure implementation of the merge sort algorithm in Python
    :param collection: some mutable ordered collection with heterogeneous
    comparable items inside
    :return: the same collection ordered by ascending
    Examples:
    >>> merge_sort([0, 5, 3, 2, 2])
    [0, 2, 2, 3, 5]
    >>> merge_sort([])
    []
    >>> merge_sort([-2, -5, -45])
    [-45, -5, -2]
    """
    length = len(collection)
    if length > 1:
        midpoint = length // 2
        left_half = merge_sort(collection[:midpoint])
        right_half = merge_sort(collection[midpoint:])
        i = 0
        j = 0
        k = 0
        left_length = len(left_half)
        right_length = len(right_half)
        while i < left_length and j < right_length:
            if left_half[i] < right_half[j]:
                collection[k] = left_half[i]
                i += 1
            else:
                collection[k] = right_half[j]
                j += 1
            k += 1

        while i < left_length:
            collection[k] = left_half[i]
            i += 1
            k += 1

        while j < right_length:
            collection[k] = right_half[j]
            j += 1
            k += 1

    return collection

def is_sorted(a):
    for n in range(len(a) - 1):
        if a[n] > a[n + 1]:
            return 'not sorted'
    return 'sorted'

# Initialize
list_len = 100000                           # Define list len
print("Set list len to %s" % list_len)
data = np.arange(0, list_len, 1)            # Create array of numbers
# Alternatively: data = list(np.arange(0, list_len, 1))  <-- This WILL NOT cause an error
print("Created array")

# Shuffle
print("Shuffling array")
random.shuffle(data)                        # Shuffle array
print("List: %s" % is_sorted(data))         # Verify that list is not sorted

# Sort (merge sort)
print("Sorting array with merge sort")
merge_sort(data)                            # Sort with merge sort      
print("List: %s" % is_sorted(data))         # Verify that list is sorted

# Shuffle
print("Shuffling array")
random.shuffle(data)                        # Reshuffle list
print("List: %s" % is_sorted(data))         # Verify that list is not sorted

# Sort (quick sort)
print("Sorting array with quick sort")
print(quick_sort(data))                     # Sort with quick sort
print("List: %s" % is_sorted(data))         # Verify that list is sorted

完整的追溯:

Traceback (most recent call last):
File "Untitled 3.py", line 99, in <module>
    print(quick_sort(data))                     # Sort with quick sort
  File "Untitled 3.py", line 24, in quick_sort
    return quick_sort(LESSER) + [PIVOT] + quick_sort(GREATER)
  File "Untitled 3.py", line 24, in quick_sort
    return quick_sort(LESSER) + [PIVOT] + quick_sort(GREATER)
  File "Untitled 3.py", line 24, in quick_sort
    return quick_sort(LESSER) + [PIVOT] + quick_sort(GREATER)
  [Previous line repeated 993 more times]
  File "Untitled 3.py", line 22, in quick_sort
    GREATER = [ element for element in ARRAY[1:] if element > PIVOT ]
  File "Untitled 3.py", line 22, in <listcomp>
    GREATER = [ element for element in ARRAY[1:] if element > PIVOT ]
RecursionError: maximum recursion depth exceeded in comparison

当quicksort尝试对列表进行排序时,显然会发生错误。注意:我知道使用列表会更快,我知道我可以提高递归限制。我知道这可能是由于快速排序已经排序的列表引起的,但我的代码证明这不是正在发生的事情。另外,正如我之前所说的,quicksort本身可以正常工作,所以这不是由无限递归循环引起的。我出于好奇而问这个问题,以便更好地理解它为什么会发生。

1 个答案:

答案 0 :(得分:2)

错误在于merge_sort

numpy数组和列表之间的一个重要区别是前者在切片时返回视图,而后者返回副本。

因此collectionleft_halfright_half在处理数组时都引用相同的数据,而在列表中left_halfright_half将是切片副本。

您可以通过强制复制或写入新分配的输出来解决此问题。

由于这个错误最终会被覆盖,而其他元素会多次出现。事实上,当我进行测试时,有很多零。

这会触发quick_sort中的最坏情况行为:在一个相等元素的块中,递归将一次削减一个,这使得它达到了递归限制。

我不知道教科书的内容是什么,但你可以在第三组中收集相同的元素。