pytorch lstm tutorial初始化变量

时间:2018-01-23 23:45:58

标签: lstm pytorch

我正在浏览lstm的pytorch教程,以及他们使用的代码:

lstm = nn.LSTM(3, 3)  # Input dim is 3, output dim is 3
inputs = [autograd.Variable(torch.randn((1, 3)))
          for _ in range(5)]  # make a sequence of length 5

# initialize the hidden state.
hidden = (autograd.Variable(torch.randn(1, 1, 3)),
          autograd.Variable(torch.randn((1, 1, 3))))
for i in inputs:
    # Step through the sequence one element at a time.
    # after each step, hidden contains the hidden state.
    out, hidden = lstm(i.view(1, 1, -1), hidden)

对于变量hidden,它初始化为元组,结果为:

(Variable containing:
(0 ,.,.) = 
  0.4251 -1.2328 -0.6195
[torch.FloatTensor of size 1x1x3]
, Variable containing:
(0 ,.,.) = 
  1.5133  1.9954 -0.6585
[torch.FloatTensor of size 1x1x3]
)

我不明白的是

  1. (0,。,。)是一个索引吗?从我们说过(torch.randn(1,1,3))之后,它不应该初始化所有三个数字吗?

  2. torch.randn(1,1,3)和torch.randn((1,1,3))有什么区别?

1 个答案:

答案 0 :(得分:1)

首先快速回答2号:它们完全相同。我不知道他们为什么会这样做。

接下来,回答问题1:

hidden是一个包含两个Variables的元组,它们基本上是1 x 1 x 3张量。

让我们关注(0 ,.,.)。如果您使用1 x 1 x 3张量而不是2 x 2张量,则可以打印出以下内容:

0.1 0.2
0.3 0.4

但是在屏幕上代表三维事物很难。即使它有点愚蠢,但在开头有一个aditional 1会改变二维张量到三维张量。所以,相反,Pytorch打印出#34;切片"张量的。在这种情况下,你只有一个"切片"这恰好是零片。因此,您可以获得额外的(0, ,.,.),而不仅仅是打印

  0.4251 -1.2328 -0.6195

如果维度为2 x 1 x 3,则可以预期输出如下:

(0 ,.,.) = 
 -0.3027 -1.1077  0.4724

(1 ,.,.) = 
  1.0063 -0.5936 -1.1589
[torch.FloatTensor of size 2x1x3]

正如你所看到的那样,张量中的每个元素都被初始化了。