如何使用滑动窗口调整PyTorch张量的大小?

时间:2020-02-10 19:34:05

标签: python pytorch tensor

我有一个张量,大小为:torch.Size([118160, 1])。我想做的就是将其分解为n个张量,每个张量包含100个元素,一次滑动50个元素。用PyTorch实现这一目标的最佳方法是什么?

2 个答案:

答案 0 :(得分:1)

您的元素数量需要除以100。如果不是这种情况,则可以使用填充进行调整。

您可以先在原始列表上进行拆分。 然后在列表上进行拆分,从前一个列表中删除前50个元素。 如果要保留原始顺序,则可以从A和B采样交替顺序。

A = yourtensor
B = yourtensor[50:] + torch.zeros(50,1)
A_ = A.view(100,-1)
B_ = B.view(100,-1)

答案 1 :(得分:1)

可能的解决方案是:

window_size = 100
stride = 50
splits = [x[i:min(x.size(0),i+window_size)] for i in range(0,x.size(0),stride)]

但是,最后几个元素要短于window_size。如果不希望这样做,您可以执行以下操作:

splits = [x[i:i+window_size] for i in range(0,x.size(0)-window_size+1,stride)]

编辑:

更具可读性的解决方案:

# if keep_short_tails is set to True, the slices shorter than window_size at the end of the result will be kept 
def window_split(x, window_size=100, stride=50, keep_short_tails=True):
  length = x.size(0)
  splits = []

  if keep_short_tails:
    for slice_start in range(0, length, stride):
      slice_end = min(length, slice_start + window_size)
      splits.append(x[slice_start:slice_end])
  else:
    for slice_start in range(0, length - window_size + 1, stride):
      slice_end = slice_start + window_size
      splits.append(x[slice_start:slice_end])

  return splits