如何在火炬7中进行多任务学习?

时间:2017-12-12 12:16:27

标签: lua torch

Simple multi-task network can be done here.但我想要这样的事情enter image description here。 现在我构建如下模型:

model = nn.Sequential()
model:add(nn.Linear(3,5))
prl1 = nn.ConcatTable()
prl1:add(nn.Linear(5,1))
prl2 = nn.ConcatTable()
prl2:add(nn.Linear(5,1))
prl2:add(nn.Linear(5,1))
prl1:add(prl2)
model:add(prl1)

我的输出是:

input = torch.rand(5,3)
output = model:forward(input)
output
{
  1 : DoubleTensor - size: 5x1
  2 : 
    {
      1 : DoubleTensor - size: 5x1
      2 : DoubleTensor - size: 5x1
    }
}

我应该如何构建我的标准?

1 个答案:

答案 0 :(得分:0)

我似乎通过两个步骤弄明白:

1.在上述网络中使用nn.Concat而不是nn.ConcatTable,这使输出成为一个简单的NxM张量,例如:使用nn.Concat代替nn.ConcatTable时,5x3张量将进入上述网络。

2.获得NxM张量后,我使用nn.ConcatTable,nn.Concat和nn.Select的组合使输出成为包含每个结果Tensor的简单表。

以下是第2步的简单示例:

model = nn.Sequential()
model:add(nn.Linear(3,5))

prl = nn.ConcatTable()

spl1 = nn.Concat(2)

seq1 = nn.Sequential()
seq1:add(nn.Select(2, 1))
seq1:add(nn.Reshape(1))

seq2 = nn.Sequential()
seq2:add(nn.Select(2, 2))
seq2:add(nn.Reshape(1))

seq3 = nn.Sequential()
seq3:add(nn.Select(2, 3))
seq3:add(nn.Reshape(1))

spl1:add(seq1)
spl1:add(seq2)
spl1:add(seq3)
prl:add(spl1)

spl2 = nn.Concat(2)

seq4 = nn.Sequential()
seq4:add(nn.Select(2, 4))
seq4:add(nn.Reshape(1))

seq5 = nn.Sequential()
seq5:add(nn.Select(2, 5))
seq5:add(nn.Reshape(1))

spl2:add(seq4)
spl2:add(seq5)
prl:add(spl2)

model:add(prl)

input = torch.rand(5,3)
output = model:forward(input)

输出如下:

th> output
{
  1 : DoubleTensor - size: 5x3
  2 : DoubleTensor - size: 5x2
}
相关问题