我有一个张量,大小为:torch.Size([118160, 1])
。我想做的就是将其分解为n个张量,每个张量包含100个元素,一次滑动50个元素。用PyTorch实现这一目标的最佳方法是什么?
答案 0 :(得分:1)
您的元素数量需要除以100。如果不是这种情况,则可以使用填充进行调整。
您可以先在原始列表上进行拆分。 然后在列表上进行拆分,从前一个列表中删除前50个元素。 如果要保留原始顺序,则可以从A和B采样交替顺序。
A = yourtensor
B = yourtensor[50:] + torch.zeros(50,1)
A_ = A.view(100,-1)
B_ = B.view(100,-1)
答案 1 :(得分:1)
可能的解决方案是:
window_size = 100
stride = 50
splits = [x[i:min(x.size(0),i+window_size)] for i in range(0,x.size(0),stride)]
但是,最后几个元素要短于window_size
。如果不希望这样做,您可以执行以下操作:
splits = [x[i:i+window_size] for i in range(0,x.size(0)-window_size+1,stride)]
编辑:
更具可读性的解决方案:
# if keep_short_tails is set to True, the slices shorter than window_size at the end of the result will be kept
def window_split(x, window_size=100, stride=50, keep_short_tails=True):
length = x.size(0)
splits = []
if keep_short_tails:
for slice_start in range(0, length, stride):
slice_end = min(length, slice_start + window_size)
splits.append(x[slice_start:slice_end])
else:
for slice_start in range(0, length - window_size + 1, stride):
slice_end = slice_start + window_size
splits.append(x[slice_start:slice_end])
return splits