帧滑动的Tensorflow变换(numpy stride_tricks)

时间:2016-02-03 10:27:04

标签: reshape tensorflow

我在张量流图中的输入是一个包含多个重叠窗口的向量。如何仅使用tensorflow操作创建此数组?

input = [1,2,3,4,5,6,7,8]
shift = 2
window_width = 4
count = (len(input) - window_width) // 2 + 1 = 3

output = [[1,2,3,4],
          [3,4,5,6],
          [5,6,7,8]]

在numpy中我会使用stride_tricks,但类似的东西在tensorflow中是不可用的。我该怎么做呢?

3 个答案:

答案 0 :(得分:4)

TensorFlow没有stride_tricks。以下内容适用于您的特定用例。

>>> b=tf.concat(0, [tf.reshape(input[i:i+4], [1, window_width]) for i in range(0, len(input) - window_width + 1, shift)])
>>> with tf.Session(""): b.eval()
... 
array([[1, 2, 3, 4],
       [3, 4, 5, 6],
       [5, 6, 7, 8]], dtype=int32)

如果您的输入很大,您可能还需要查看slice_input_producer

答案 1 :(得分:0)

您可以使用tf.map_fn()完成此操作:

input = [1,2,3,4,5,6,7,8]
shift = 2
window_width = 4
limit = len(input) - window_width + 1

input_tensor = tf.placeholder(tf.int32, shape=(8,))
output_tensor = tf.map_fn(lambda i: input_tensor[i:i+window_width], elems=tf.range(start=0, limit=limit, delta=shift))

with tf.Session() as sess:
    answer_test = sess.run(output_tensor, feed_dict = {input_tensor:input})
    print(answer_test)

:    [[1 2 3 4]
     [3 4 5 6]
     [5 6 7 8]]

答案 2 :(得分:0)

如果有人仍然需要此功能,则有一个功能tf.signal.frame()

对于问题中给出的示例,代码为:

output = tf.signal.frame(input, window_width, shift)