在没有for循环的情况下查找批处理数据每一帧中子数组首次出现的索引的最佳方法

时间:2019-02-14 10:59:08

标签: numpy for-loop tensorflow math batch-processing

我必须找到每个帧中第一次出现的子数组的索引。数据的大小为(batch_size,400)。我需要找到大小为400的每个帧中三个连续的出现索引。 数据-> [0 0 0 1 1 1 0 1 1 1 1 1][0 0 0 0 1 1 1 0 0 1 1 1] [0 1 1 1 0 0 0 1 1 1 1 1]

输出应为[3 4 1]

本机解决方案正在使用for循环,但是由于数据量很大,因此非常耗时。

numpytensorflow中任何快速有效的实施方式

1 个答案:

答案 0 :(得分:0)

对此没有简单的numpy解决方案。但是,如果您真的需要快速,可以使用numba进行以下操作:

函数find_first基本上可以完成for循环的工作。但是由于您使用的是numba,因此该方法已编译,因此速度更快。 然后,您只需使用np.apply_along_axis将方法应用于每个批次:

import numpy as np
from numba import jit


@jit(nopython=True)
def find_first(seq, arr):
    """return the index of the first occurence of item in arr"""
    for i in range(len(arr)-2):
        if np.all(seq == arr[i:i+3]):
            return i
    return -1

# construct test array
test = np.round(np.random.random((64,400)))

# this will give you the array of indices
np.apply_along_axis(lambda m: find_first(np.array([1,1,1]), m), axis=1, arr = test)

我从this answer修改了方法