我需要为每个批处理元素切片一个大小恒定的窗口,但是从不同的位置开始。例如,对于长度为2的窗口,我希望能够执行以下操作:
batch = tf.constant([[1, 2, 3],
[4, 5, 6]])
window_size = 2
window_starts = tf.constant([1, 0]) # The index from which to start slicing for each batch element.
slice_windows(batch, window_size, window_starts)
# this should return [[2, 3],
# [4, 5]]
我事先不知道window_starts是什么(它们来自数据),所以我不能只列举我需要的所有索引并使用tf.gather_nd。
此外,在对窗口进行计算之后,我需要将它们用0填充回原位(因此每个批处理元素的填充量不同):
...computation on windows...
restore_window_positions(windows, window_starts, original_size=3)
# this should return [[0, 2, 3],
# [4, 5, 0]]