如何使用环绕创建在ndarray上的滑动窗口?

时间:2020-07-14 23:03:52

标签: python numpy

我正在尝试编写一些逻辑,以返回一个数组,该数组向右移动了一步,并具有环绕效果。我依靠接收IndexError来实现环绕,但是没有引发任何错误!

def get_batches(arr, batch_size, seq_length):
    """
    Return arr data as batches of shape (batch_size, seq_length)
    """
    
    n_chars = batch_size * seq_length
    n_batches = int(np.floor(len(arr)/ n_chars))
    n_keep = n_chars * n_batches
    
    arr = arr[:n_keep].reshape(batch_size, -1)
    
    for b in range(n_batches):
        start = b * seq_length
        stop = start + seq_length
        
        x = arr[:, start:stop]
        try: 
            y = arr[:, start + 1: stop + 1]
        except IndexError:
            y = np.concatenate(x[:, 1:], arr[:, 0], axis=1)
        
        yield x, y

因此,此代码非常有用,除了产生最后一个y数组时...我得到一个(2,2)数组,而不是预期的(2,3)。那是因为从未抛出IndexError。

test = np.arange(12)
batches = get_batches(test, 2, 3)

for x, y in batches:
    print('x=', x)
    print('y=', y, '\n')

收益

x=
 [[0 1 2]
 [6 7 8]]
y=           # as expected
 [[1 2 3]
 [7 8 9]] 

x=
 [[ 3  4  5]
 [ 9 10 11]]
y=           # truncated :(
 [[ 4  5]
 [10 11]] 

有人对如何做到这一点有其他建议吗?最好是像我失败的解决方案一样简单的事情?

1 个答案:

答案 0 :(得分:1)

尝试一下:

from skimage.util.shape import view_as_windows
def get_batches2(arr, batch_size, seq_length):
    """
    Return arr data as batches of shape (batch_size, seq_length)
    """
    n_chars = batch_size * seq_length
    n_batches = int(np.floor(len(arr)/ n_chars))
    n_keep = n_chars * n_batches
    
    arr = arr[:n_keep].reshape(batch_size, -1)
    x = view_as_windows(arr, (batch_size, seq_length), seq_length)[0]
    y = view_as_windows(np.roll(arr,-1,axis=1), (batch_size, seq_length), seq_length)[0]

    return x, y

view_as_windows使用相同的共享内存(这是一个视图。您可以检查它们是否共享相同的内存)。因此,无论是通过循环产生还是返回它都没有关系。如果这是问题(特别是您的窗口不重叠),它将不使用额外的内存,并且它应该比生成器快得多。您甚至也可以通过简单地重塑而无需view_as_windows来实现此目的。