Rust-如何在集合中查找第n个最频繁的元素

时间:2020-10-08 12:13:54

标签: rust

我无法想象以前没有问过这个问题,但是我到处搜索并且找不到答案。

我有一个迭代器,其中包含重复的元素。我想计算每个元素在此迭代中出现的次数,并返回第n个最频繁的元素。

我有一个工作代码可以做到这一点,但是我真的怀疑它是实现这一目标的最佳方法。

use std::collections::{BinaryHeap, HashMap};

// returns n-th most frequent element in collection
pub fn most_frequent<T: std::hash::Hash + std::cmp::Eq + std::cmp::Ord>(array: &[T], n: u32) -> &T {
    // intialize empty hashmap
    let mut map = HashMap::new();

    // count occurence of each element in iterable and save as (value,count) in hashmap
    for value in array {
        // taken from https://doc.rust-lang.org/std/collections/struct.HashMap.html#method.entry
        // not exactly sure how this works
        let counter = map.entry(value).or_insert(0);
        *counter += 1;
    }

    // determine highest frequency of some element in the collection
    let mut heap: BinaryHeap<_> = map.values().collect();
    let mut max = heap.pop().unwrap();
    // get n-th largest value
    for _i in 1..n {
        max = heap.pop().unwrap();
    }

    // find that element (get key from value in hashmap)
    // taken from https://stackoverflow.com/questions/59401720/how-do-i-find-the-key-for-a-value-in-a-hashmap
    map.iter()
        .find_map(|(key, &val)| if val == *max { Some(key) } else { None })
        .unwrap()
}

是否有更好的方法或更优化的std方法来实现我想要的?也许有一些我可以使用的社区制作的板条箱。

1 个答案:

答案 0 :(得分:1)

您的实现的时间复杂度为Ω( n log n ),其中 n 是数组的长度。该问题的最佳解决方案是检索第 k 个最频繁元素的复杂度Ω( n log k )。这种最佳解决方案的通常实现确实涉及到二进制堆,但并不以您使用它的方式进行。

以下是常见算法的建议实现:

use std::cmp::{Eq, Ord, Reverse};
use std::collections::{BinaryHeap, HashMap};
use std::hash::Hash;

pub fn most_frequent<T>(array: &[T], k: usize) -> Vec<(usize, &T)>
where
    T: Hash + Eq + Ord,
{
    let mut map = HashMap::new();
    for x in array {
        *map.entry(x).or_default() += 1;
    }

    let mut heap = BinaryHeap::with_capacity(k + 1);
    for (x, count) in map.into_iter() {
        heap.push(Reverse((count, x)));
        if heap.len() > k {
            heap.pop();
        }
    }
    heap.into_sorted_vec().into_iter().map(|r| r.0).collect()
}

Playground

我更改了函数的原型,以返回k最频繁元素及其计数的向量,因为无论如何这都是您需要跟踪的。如果只希望第k个最频繁的元素,则可以使用[k - 1][1]为结果建立索引。

该算法本身首先以与您的代码相同的方式构建元素计数图–我只是以更简洁的形式编写了它。

接下来,我们为最常见的元素设置一个BinaryHeap。每次迭代之后,此堆最多包含k个元素,这是迄今为止看到的最频繁的元素。如果堆中有超过k个元素,则删除最不频繁的元素。由于我们总是删除到目前为止看到的最不频繁的元素,因此堆总是保留到目前为止看到的k最频繁的元素。我们需要使用Reverse包装器来获取最小堆,例如documented in the documentation of BinaryHeap

最后,我们将结果收集到向量中。 into_sorted_vec()函数基本上可以为我们完成这项工作,但是我们仍然想从其Reverse包装器中解包项目–包装器是我们函数的实现细节,不应返回给调用方。 / p>

(在Rust Nightly中,我们也可以使用into_iter_sorted() method,节省一个向量分配。)

此答案中的代码可确保堆基本上限于k个元素,因此插入堆的复杂度为Ω(log k)。在您的代码中,您将数组中的所有元素一次推入堆中,而没有限制堆的大小,因此插入时的复杂度为Ω(log n)。本质上,您使用二进制堆对计数列表进行排序。哪个可行,但肯定不是实现这一目标的最简单或最快的方法,因此走这条路线没有什么道理。