我输入的是形状(B(色),F(特征),N(颂数),T(时间戳))。现在,如果我应用形状为(1,2)的核的2d卷积,我将总共有(F_out,F_in,1,2)个权重来学习。我想对此进行扩展,以便对输入中的每个节点都有自己的形状为(1,2)的过滤器。你们有谁知道我应该从哪里开始吗?到目前为止,我遍历了所有N,并将滤波器应用于其各自的输入。不幸的是,这种方法非常慢。
答案 0 :(得分:2)
您正在寻找“分组卷积”。
有关groups
参数的nn.Conv2d
的文档:
在
groups=2
,该操作等效于具有两个并排的conv层,每个conv层看到一半的输入通道,并产生一半的输出通道,并且随后都被串联。 / p>
在您的情况下,您需要groups=
个节点。
这不是那么简单,因为您要“合并”要素和节点,并且只在“特征” +“节点”维度上进行一维分组卷积。
此外,您需要在“节点”和“特征”之间进行置换,以便根据节点对特征进行分组。
b = 10;
inf = 8;
outf = 13;
n = 3;
t = 50;
x = torch.rand((b, inf, n, t)) # input tensor
gconv = nn.Conv1d(inf, outf, kernel_size=(2), groups=n) #grouped conv
x_ready = x.permute(0, 2, 1, 3).view(b, inf*n, t)
y_grouped = gconv(x_ready)
# "fix" y
y = y_grouped.view(n, n, outf, t).permute(0, 2, 1, 3) # now y is b-outf-n-t