我想识别我的numpy数组中的索引,其值是集合中包含的值之一;例如,这设置为(5,6,7,8)。
现在我正在做
np.where(np.isin(arr, [5,6,7,8]))
工作正常。我想知道是否有更好的方法来实现这一功能。
答案 0 :(得分:4)
您拥有的代码是正确和合理的。你应该保留它。
答案 1 :(得分:3)
如果您不知道替代方案是什么,您无法知道当前的解决方案是否良好。
首先,
np.where(np.isin(arr, val))
适用于任何一般情况。 np.isin
对arr
中的元素val
进行线性搜索。
您也可以将np.where
替换为np.nonzero
,这对于较大的N来说要快一些。
接下来,有
(arr[:, None] == val).argmax(0)
对于小大小的arr和val(N <100),非常快。
最后,如果arr
已排序,我建议np.searchsorted
。
np.searchsorted(arr, val)
arr = np.arange(100000)
val = np.random.choice(arr, 1000)
%timeit np.where(np.isin(arr, val))
%timeit np.nonzero(np.isin(arr, val))
%timeit (arr[:, None] == val).argmax(0)
%timeit np.searchsorted(arr, val)
8.3 ms ± 320 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
7.88 ms ± 791 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
861 ms ± 6.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
235 µs ± 31.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
(arr[:, None] == val).argmax(0)
的问题是内存井喷 - 比较被广播,引入了一个非常非常稀疏的矩阵,当N很大时这是非常浪费的(因此不要将它用于大N)。
答案 2 :(得分:2)
您的方法有效,它也适用于多维数组。
dependencies {
implementation fileTree(dir: 'libs', include: ['*.jar'])
implementation 'com.android.support:appcompat-v7:27.1.1'
implementation 'com.android.support.constraint:constraint-layout:1.1.0'
testImplementation 'junit:junit:4.12'
androidTestImplementation 'com.android.support.test:runner:1.0.2'
androidTestImplementation 'com.android.support.test.espresso:espresso-core:3.0.2'
implementation 'com.squareup.picasso:picasso:2.5.2'
implementation 'com.android.support:recyclerview-v7:27.1.1'
implementation 'com.squareup.retrofit2:retrofit:2.4.0'
implementation 'com.squareup.okhttp3:okhttp:3.10.0'
implementation 'com.google.code.gson:gson:2.8.2'
implementation 'com.squareup.retrofit2:converter-gson:2.1.0'
implementation 'com.squareup.okhttp3:logging-interceptor:3.3.0'
}
这直接来自此处的文档:https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.where.html#numpy.where