torch / nn - 明智地连接Tensors数组

时间:2016-06-10 12:03:09

标签: lua jointable torch

这个问题的主题是加入神经网络的张量,其中包括用于Lua的torch / nn和torch / nngraph库。我几周前开始在Lua编码,所以我的经验很少。在下面的文本中,我将lua表称为数组。

上下文

我正在使用循环神经网络进行语音识别。 在网络中的某个时刻,有Nm张量阵列。

a = {a1, a2, ..., aM},
b = {b1, b2, ..., bM}, 
... N times

其中aibi是张量,{}代表数组。

需要做的是按元素方式连接所有这些数组,以便outputM张量的数组,其中output[i]是加入{的每个第i个张量的结果第二维上的{1}}数组。

N

实施例

output = {z1, z2, ..., zM} 曾代表张量

||

因此,大小为2x2的x = {|1 1|, |2 2|} |1 1| |2 2| Tensors of size 2x2 y = {|3 3 3|, |4 4 4|} |3 3 3| |4 4 4| Tensors of size 2x3 | | Join{x,y} \/ z = {|1 1 3 3 3|, |2 2 4 4 4|} |1 1 3 3 3| |2 2 4 4 4| Tensors of size 2x5 的第一个Tensor与第二个维度上第一个大小为2x3的x Tensor相连,每个数组的第二个Tensor相同,导致{{1}一系列的张量传感器2x5。

问题

现在这是一个基本的连接,但我似乎无法在火炬/ nn库中找到一个允许我这样做的模块。我当然可以编写自己的模块,但是如果已经存在的模块就可以了,那么我宁愿选择它。

我知道连接表的唯一现有模块(显然)是JoinTable。它需要一系列的Tensors并将它们连接在一起。我想以元素方式加入张量数组。

此外,当我们向网络提供输入时,y数组中的张量数量会有所不同,因此上述上下文中的z不是常数。

我认为为了使用模块JoinTable我可以做的是将我的数组转换为张量,然后在转换后的N张量上转换m。但是我再次需要一个进行这种转换的模块和另一个转换回数组的模块,以便将其提供给网络的下一层。

最后的手段

编写一个新模块,迭代所有给定的数组并以元素方式连接。当然它可以做到,但这篇文章的全部目的是找到一种方法来避免编写有臭味的模块。对我来说这样的模块已经不存在似乎很奇怪。

结论

我最终决定按照我在最后手段中写的那样做。我编写了一个新模块,它迭代所有给定的数组并连接元素。

尽管如此,@ fmguler给出的答案也是如此,无需编写新模块。

1 个答案:

答案 0 :(得分:2)

您可以使用nn.SelectTable和nn.JoinTable这样做;

require 'nn'

x = {torch.Tensor{{1,1},{1,1}}, torch.Tensor{{2,2},{2,2}}}
y = {torch.Tensor{{3,3,3},{3,3,3}}, torch.Tensor{{4,4,4},{4,4,4}}}

res = {}
res[1] = nn.JoinTable(2):forward({nn.SelectTable(1):forward(x),nn.SelectTable(1):forward(y)})
res[2] = nn.JoinTable(2):forward({nn.SelectTable(2):forward(x),nn.SelectTable(2):forward(y)})

print(res[1])
print(res[2])

如果您希望在模块中完成此操作,请将其包装在nnGraph中;

require 'nngraph'

x = {torch.Tensor{{1,1},{1,1}}, torch.Tensor{{2,2},{2,2}}}
y = {torch.Tensor{{3,3,3},{3,3,3}}, torch.Tensor{{4,4,4},{4,4,4}}}

xi = nn.Identity()()
yi = nn.Identity()()
res = {}
--you can loop over columns here>>
res[1] = nn.JoinTable(2)({nn.SelectTable(1)(xi),nn.SelectTable(1)(yi)})
res[2] = nn.JoinTable(2)({nn.SelectTable(2)(xi),nn.SelectTable(2)(yi)})
module = nn.gModule({xi,yi},res)

--test like this
result = module:forward({x,y})
print(result)
print(result[1])
print(result[2])

--gives the result
th> print(result)
{
  1 : DoubleTensor - size: 2x5
  2 : DoubleTensor - size: 2x5
}

th> print(result[1])
 1  1  3  3  3
 1  1  3  3  3
[torch.DoubleTensor of size 2x5]

th> print(result[2])
 2  2  4  4  4
 2  2  4  4  4
[torch.DoubleTensor of size 2x5]