TensorFlow - 每个批次元素在不同位置切片

时间:2018-01-12 20:47:15

标签: tensorflow

我需要为每个批处理元素切片一个大小恒定的窗口,但是从不同的位置开始。例如,对于长度为2的窗口,我希望能够执行以下操作:

batch = tf.constant([[1, 2, 3],
                     [4, 5, 6]])
window_size = 2
window_starts = tf.constant([1, 0])  # The index from which to start slicing for each batch element.

slice_windows(batch, window_size, window_starts) 

# this should return [[2, 3],
#                     [4, 5]]

我事先不知道window_starts是什么(它们来自数据),所以我不能只列举我需要的所有索引并使用tf.gather_nd。

此外,在对窗口进行计算之后,我需要将它们用0填充回原位(因此每个批处理元素的填充量不同):

...computation on windows...

restore_window_positions(windows, window_starts, original_size=3)

# this should return [[0, 2, 3],
#                     [4, 5, 0]]

0 个答案:

没有答案