我必须找到每个帧中第一次出现的子数组的索引。数据的大小为(batch_size,400)。我需要找到大小为400的每个帧中三个连续的出现索引。
数据-> [0 0 0 1 1 1 0 1 1 1 1 1][0 0 0 0 1 1 1 0 0 1 1 1] [0 1 1 1 0 0 0 1 1 1 1 1]
输出应为[3 4 1]
本机解决方案正在使用for循环,但是由于数据量很大,因此非常耗时。
numpy
或tensorflow
中任何快速有效的实施方式
答案 0 :(得分:0)
对此没有简单的numpy解决方案。但是,如果您真的需要快速,可以使用numba进行以下操作:
函数find_first
基本上可以完成for循环的工作。但是由于您使用的是numba,因此该方法已编译,因此速度更快。
然后,您只需使用np.apply_along_axis
将方法应用于每个批次:
import numpy as np
from numba import jit
@jit(nopython=True)
def find_first(seq, arr):
"""return the index of the first occurence of item in arr"""
for i in range(len(arr)-2):
if np.all(seq == arr[i:i+3]):
return i
return -1
# construct test array
test = np.round(np.random.random((64,400)))
# this will give you the array of indices
np.apply_along_axis(lambda m: find_first(np.array([1,1,1]), m), axis=1, arr = test)
我从this answer修改了方法