如何使用张量流有效提取给定长度的所有切片

时间:2018-10-27 15:57:12

标签: python tensorflow

我试图提取沿2维张量的第0轴的所有长度为4的切片。到目前为止,我可以将纯Python与tensorflow混合在一起。

r = test.shape[0] # test should be a tensor
n = 4
a_list = list(range(r))
the_list = np.array([a_list[slice(i, i+n)] for i in range(r - n+1)])
test_stacked = tf.stack(tf.gather(test, the_list))

在不使用纯Python的情况下,这样做的有效方法是什么?请注意,“测试”数组实际上应该是张量,因此在执行图的第一部分之前,其形状是未知的。

完整的示例:

array = np.array([[0, 1],[1, 2],[2, 3],[3, 4],[4, 5],[5, 6]])
array.shape # (6,2)

r = array.shape[0]
n = 4
a_list = list(range(r))
the_list = np.array([a_list[slice(i, i+n)] for i in range(r - n+1)])

result = array[the_list] # all possible slices of length 4 of the array along 0th axis
result.shape # (3, 4, 2)

结果:

[[[0 1]
  [1 2]
  [2 3]
  [3 4]]

 [[1 2]
  [2 3]
  [3 4]
  [4 5]]

 [[2 3]
  [3 4]
  [4 5]
  [5 6]]]

2 个答案:

答案 0 :(得分:1)

我相信您正在寻找gather_nd

# a is a tensor of size (6, 2)

def get_indices(l, d):
    return [[[j] for j in range(i, i + d)] for i in range(l - d + 1)]

b = tf.gather_nd(a, get_indices(6, 4))
# b is a tensor of shape (3, 4, 2)

答案 1 :(得分:1)

您可能想尝试更一般的tf.extract_image_patches

import tensorflow as tf

a = tf.constant([[0, 1],[1, 2],[2, 3],[3, 4],[4, 5],[5, 6]])
# tf.extract_image_patches requires a [batch, in_rows, in_cols, depth] tensor
a = a[None, :, :, None]
b = tf.extract_image_patches(a,
  ksizes=[1, 4, 2, 1],
  strides=[1, 1, 1, 1],
  rates=[1, 1, 1, 1],
  padding='VALID')
b = tf.reshape(tf.squeeze(b), [-1, 4, 2])

sess = tf.InteractiveSession()
print(b.eval())