我经历了official doc。我很难理解此功能的用途以及它的工作方式。有人可以用Layman解释吗?
尽管我使用的Pytorch版本与文档匹配,但我提供的示例却出现错误。也许纠正错误(应该这样做)应该教给我一些东西?文档中给出的代码段是:
fold = nn.Fold(output_size=(4, 5), kernel_size=(2, 2))
input = torch.randn(1, 3 * 2 * 2, 1)
output = fold(input)
output.size()
,固定的代码段是:
fold = nn.Fold(output_size=(4, 5), kernel_size=(2, 2))
input = torch.randn(1, 3 * 2 * 2, 3 * 2 * 2)
output = fold(input)
output.size()
谢谢!
答案 0 :(得分:11)
unfold
将张量想象成一个较长的张量,其中重复的列/行值“折叠”在彼此的顶部,然后“展开”:
size
决定折叠的大小step
决定折叠的频率例如对于 2x5 张量,用 step=1
展开它,并在 size=2
上修补 dim=1
:
x = torch.tensor([[1,2,3,4,5],
[6,7,8,9,10]])
>>> x.unfold(1,2,1)
tensor([[[ 1, 2], [ 2, 3], [ 3, 4], [ 4, 5]],
[[ 6, 7], [ 7, 8], [ 8, 9], [ 9, 10]]])
fold
与此操作大致相反,但“重叠”值在输出中求和。
答案 1 :(得分:4)
unfold
和fold
用于促进“滑动窗口”操作(如卷积)。
假设您要将功能foo
应用于特征图/图像中的每个5x5窗口:
from torch.nn import functional as f
windows = f.unfold(x, kernel_size=5)
现在windows
具有size
个批处理(5 * 5 * {x.size(1)
)-num_windows,您可以在foo
上应用windows
:
processed = foo(windows)
现在,您需要将processed
折回到原始的x
大小:
out = f.fold(processed, x.shape[-2:], kernel_size=5)
您需要注意padding
和kernel_size
的问题,这可能会影响您将processed
折回x
的大小。
此外,fold
求和在重叠的元素上,因此您可能想将fold
的输出除以补丁大小。
答案 2 :(得分:4)
x = torch.arange(1, 9).float()
print(x)
# dimension, size, step
print(x.unfold(0, 2, 1))
print(x.unfold(0, 3, 2))
出局:
tensor([1., 2., 3., 4., 5., 6., 7., 8.])
tensor([[1., 2.],
[2., 3.],
[3., 4.],
[4., 5.],
[5., 6.],
[6., 7.],
[7., 8.]])
tensor([[1., 2., 3.],
[3., 4., 5.],
[5., 6., 7.]])
import torch
patch=(3,3)
x=torch.arange(16).float()
print(x, x.shape)
x2d = x.reshape(1,1,4,4)
print(x2d, x2d.shape)
h,w = patch
c=x2d.size(1)
print(c) # channels
# unfold(dimension, size, step)
r = x2d.unfold(2,h,1).unfold(3,w,1).transpose(1,3).reshape(-1, c, h, w)
print(r.shape)
print(r) # result
tensor([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13.,
14., 15.]) torch.Size([16])
tensor([[[[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.],
[12., 13., 14., 15.]]]]) torch.Size([1, 1, 4, 4])
1
torch.Size([4, 1, 3, 3])
tensor([[[[ 0., 1., 2.],
[ 4., 5., 6.],
[ 8., 9., 10.]]],
[[[ 4., 5., 6.],
[ 8., 9., 10.],
[12., 13., 14.]]],
[[[ 1., 2., 3.],
[ 5., 6., 7.],
[ 9., 10., 11.]]],
[[[ 5., 6., 7.],
[ 9., 10., 11.],
[13., 14., 15.]]]])