在数组中查找最大值索引的最快方法是什么?

时间:2019-09-05 23:11:46

标签: arrays rust

我有一个f32类型的2D数组(来自ndarray::ArrayView2),我想在每一行中找到最大值的索引,然后将索引值放入另一个数组。

Python中的等效项类似于:

import numpy as np

for i in range (0, max_val, batch_size):
   sims = xp.dot(batch, vectors.T) 
   # sims is the dot product of batch and vectors.T
   # the shape is, for example, (1024, 10000)

   best_rows[i: i+batch_size] = sims.argmax(axis = 1)

在Python中,功能.argmax非常快,但在Rust中看不到任何类似的功能。最快的方法是什么?

2 个答案:

答案 0 :(得分:1)

考虑一般Ord类型的简单情况:答案会有所不同,具体取决于您是否知道值是Copy,但这是代码:

fn position_max_copy<T: Ord + Copy>(slice: &[T]) -> Option<usize> {
    slice.iter().enumerate().max_by_key(|(_, &value)| value).map(|(idx, _)| idx)
}

fn position_max<T: Ord>(slice: &[T]) -> Option<usize> {
    slice.iter().enumerate().max_by(|(_, value0), (_, value1)| value0.cmp(value1)).map(|(idx, _)| idx)
}

基本思想是我们将数组中每个项的[引用](实际上是切片-不管是Vec还是数组还是更奇特的东西)都与其索引配对,请使用{{ 1}}函数仅根据值(而不是索引)查找最大值,然后仅返回索引。如果切片为空,std::iter::Iterator将被返回。根据文档,将返回最右边的索引。如果需要最左侧,请在{em}之后None rev()

enumerate()rev()enumerate()max_by_key()记录在here中; max_by()已记录在here中(但您需要在没有文档记录的情况下将其召回清单,以防锈开发); slice::iter()map记录在here中(同上)。哦,Option::map()cmp,但是大多数时候您可以使用不需要的Ord::cmp版本(例如,如果要比较整数)。


现在要注意的是:由于IEEE浮动工作的方式,Copy不是f32。大多数语言都忽略了这一点,并且算法有误。在Ord上提供总订单的最受欢迎的板条箱(通过声明所有NaN相等且大于所有数字)似乎是ordered-float。假设正确实施,它应该非常轻巧。它确实引入了Ord,但这是最受欢迎的数字库的一部分,因此很可能已经被其他依赖项引入了。

在这种情况下,您可以通过在切片迭代器(num_traits)上映射ordered_float::OrderedFloat(元组类型的“构造函数”)来使用它。由于只需要最大元素的位置,因此以后无需提取f32。

答案 1 :(得分:0)

approach from @David A很酷,但是如上所述,有一个陷阱:f32f64不实现Ord::cmp。 (这在您所知的地方确实很痛苦。)

有多种解决方法:您可以自己实现cmp,也可以使用ordered-float等。

就我而言,这是一个较大项目的一部分,我们在使用外部软件包时非常小心。此外,我很确定我们没有任何NaN值。因此,我更喜欢使用fold,如果您仔细查看max_by_key源代码,它们也是他们一直在使用的。

for (i, row) in matrix.axis_iter(Axis(1)).enumerate() {
    let (max_idx, max_val) =
        row.iter()
            .enumerate()
            .fold((0, row[0]), |(idx_max, val_max), (idx, val)| {
                if &val_max > val {
                    (idx_max, val_max)
                } else {
                    (idx, *val)
                }
            });
}