我有一个f32
类型的2D数组(来自ndarray::ArrayView2
),我想在每一行中找到最大值的索引,然后将索引值放入另一个数组。
Python中的等效项类似于:
import numpy as np
for i in range (0, max_val, batch_size):
sims = xp.dot(batch, vectors.T)
# sims is the dot product of batch and vectors.T
# the shape is, for example, (1024, 10000)
best_rows[i: i+batch_size] = sims.argmax(axis = 1)
在Python中,功能.argmax
非常快,但在Rust中看不到任何类似的功能。最快的方法是什么?
答案 0 :(得分:1)
考虑一般Ord
类型的简单情况:答案会有所不同,具体取决于您是否知道值是Copy
,但这是代码:
fn position_max_copy<T: Ord + Copy>(slice: &[T]) -> Option<usize> {
slice.iter().enumerate().max_by_key(|(_, &value)| value).map(|(idx, _)| idx)
}
fn position_max<T: Ord>(slice: &[T]) -> Option<usize> {
slice.iter().enumerate().max_by(|(_, value0), (_, value1)| value0.cmp(value1)).map(|(idx, _)| idx)
}
基本思想是我们将数组中每个项的[引用](实际上是切片-不管是Vec还是数组还是更奇特的东西)都与其索引配对,请使用{{ 1}}函数仅根据值(而不是索引)查找最大值,然后仅返回索引。如果切片为空,std::iter::Iterator
将被返回。根据文档,将返回最右边的索引。如果需要最左侧,请在{em}之后None
rev()
。
enumerate()
,rev()
,enumerate()
和max_by_key()
记录在here中; max_by()
已记录在here中(但您需要在没有文档记录的情况下将其召回清单,以防锈开发); slice::iter()
被map
记录在here中(同上)。哦,Option::map()
是cmp
,但是大多数时候您可以使用不需要的Ord::cmp
版本(例如,如果要比较整数)。
现在要注意的是:由于IEEE浮动工作的方式,Copy
不是f32
。大多数语言都忽略了这一点,并且算法有误。在Ord
上提供总订单的最受欢迎的板条箱(通过声明所有NaN相等且大于所有数字)似乎是ordered-float。假设正确实施,它应该非常轻巧。它确实引入了Ord
,但这是最受欢迎的数字库的一部分,因此很可能已经被其他依赖项引入了。
在这种情况下,您可以通过在切片迭代器(num_traits
)上映射ordered_float::OrderedFloat
(元组类型的“构造函数”)来使用它。由于只需要最大元素的位置,因此以后无需提取f32。
答案 1 :(得分:0)
approach from @David A很酷,但是如上所述,有一个陷阱:f32
和f64
不实现Ord::cmp
。 (这在您所知的地方确实很痛苦。)
有多种解决方法:您可以自己实现cmp
,也可以使用ordered-float
等。
就我而言,这是一个较大项目的一部分,我们在使用外部软件包时非常小心。此外,我很确定我们没有任何NaN
值。因此,我更喜欢使用fold
,如果您仔细查看max_by_key
源代码,它们也是他们一直在使用的。
for (i, row) in matrix.axis_iter(Axis(1)).enumerate() {
let (max_idx, max_val) =
row.iter()
.enumerate()
.fold((0, row[0]), |(idx_max, val_max), (idx, val)| {
if &val_max > val {
(idx_max, val_max)
} else {
(idx, *val)
}
});
}