Simple multi-task network can be done here.但我想要这样的事情enter image description here。 现在我构建如下模型:
model = nn.Sequential()
model:add(nn.Linear(3,5))
prl1 = nn.ConcatTable()
prl1:add(nn.Linear(5,1))
prl2 = nn.ConcatTable()
prl2:add(nn.Linear(5,1))
prl2:add(nn.Linear(5,1))
prl1:add(prl2)
model:add(prl1)
我的输出是:
input = torch.rand(5,3)
output = model:forward(input)
output
{
1 : DoubleTensor - size: 5x1
2 :
{
1 : DoubleTensor - size: 5x1
2 : DoubleTensor - size: 5x1
}
}
我应该如何构建我的标准?
答案 0 :(得分:0)
我似乎通过两个步骤弄明白:
1.在上述网络中使用nn.Concat而不是nn.ConcatTable,这使输出成为一个简单的NxM张量,例如:使用nn.Concat代替nn.ConcatTable时,5x3张量将进入上述网络。
2.获得NxM张量后,我使用nn.ConcatTable,nn.Concat和nn.Select的组合使输出成为包含每个结果Tensor的简单表。
以下是第2步的简单示例:
model = nn.Sequential()
model:add(nn.Linear(3,5))
prl = nn.ConcatTable()
spl1 = nn.Concat(2)
seq1 = nn.Sequential()
seq1:add(nn.Select(2, 1))
seq1:add(nn.Reshape(1))
seq2 = nn.Sequential()
seq2:add(nn.Select(2, 2))
seq2:add(nn.Reshape(1))
seq3 = nn.Sequential()
seq3:add(nn.Select(2, 3))
seq3:add(nn.Reshape(1))
spl1:add(seq1)
spl1:add(seq2)
spl1:add(seq3)
prl:add(spl1)
spl2 = nn.Concat(2)
seq4 = nn.Sequential()
seq4:add(nn.Select(2, 4))
seq4:add(nn.Reshape(1))
seq5 = nn.Sequential()
seq5:add(nn.Select(2, 5))
seq5:add(nn.Reshape(1))
spl2:add(seq4)
spl2:add(seq5)
prl:add(spl2)
model:add(prl)
input = torch.rand(5,3)
output = model:forward(input)
输出如下:
th> output
{
1 : DoubleTensor - size: 5x3
2 : DoubleTensor - size: 5x2
}