我正在尝试编写一些逻辑,以返回一个数组,该数组向右移动了一步,并具有环绕效果。我依靠接收IndexError来实现环绕,但是没有引发任何错误!
def get_batches(arr, batch_size, seq_length):
"""
Return arr data as batches of shape (batch_size, seq_length)
"""
n_chars = batch_size * seq_length
n_batches = int(np.floor(len(arr)/ n_chars))
n_keep = n_chars * n_batches
arr = arr[:n_keep].reshape(batch_size, -1)
for b in range(n_batches):
start = b * seq_length
stop = start + seq_length
x = arr[:, start:stop]
try:
y = arr[:, start + 1: stop + 1]
except IndexError:
y = np.concatenate(x[:, 1:], arr[:, 0], axis=1)
yield x, y
因此,此代码非常有用,除了产生最后一个y
数组时...我得到一个(2,2)
数组,而不是预期的(2,3)
。那是因为从未抛出IndexError。
test = np.arange(12)
batches = get_batches(test, 2, 3)
for x, y in batches:
print('x=', x)
print('y=', y, '\n')
收益
x=
[[0 1 2]
[6 7 8]]
y= # as expected
[[1 2 3]
[7 8 9]]
x=
[[ 3 4 5]
[ 9 10 11]]
y= # truncated :(
[[ 4 5]
[10 11]]
有人对如何做到这一点有其他建议吗?最好是像我失败的解决方案一样简单的事情?
答案 0 :(得分:1)
尝试一下:
from skimage.util.shape import view_as_windows
def get_batches2(arr, batch_size, seq_length):
"""
Return arr data as batches of shape (batch_size, seq_length)
"""
n_chars = batch_size * seq_length
n_batches = int(np.floor(len(arr)/ n_chars))
n_keep = n_chars * n_batches
arr = arr[:n_keep].reshape(batch_size, -1)
x = view_as_windows(arr, (batch_size, seq_length), seq_length)[0]
y = view_as_windows(np.roll(arr,-1,axis=1), (batch_size, seq_length), seq_length)[0]
return x, y
view_as_windows
使用相同的共享内存(这是一个视图。您可以检查它们是否共享相同的内存)。因此,无论是通过循环产生还是返回它都没有关系。如果这是问题(特别是您的窗口不重叠),它将不使用额外的内存,并且它应该比生成器快得多。您甚至也可以通过简单地重塑而无需view_as_windows
来实现此目的。