PyTorch LSTM状态

时间:2018-08-14 16:11:26

标签: python lstm pytorch

请考虑以下代码:

lstm = nn.LSTM(10, 5, batch_first=True)
states = (torch.rand(1, 1, 5), torch.rand(1, 1, 5))
h, states = lstm(torch.rand(1, 1, 10), states)
print('h:')
print(h)
print('states[0]:')
print(states[0])

输出:

h:
tensor([[[0.2808, 0.3357, 0.1290, 0.1413, 0.2648]]], grad_fn=<TransposeBackward0>)
states[0]:
tensor([[[0.2808, 0.3357, 0.1290, 0.1413, 0.2648]]], grad_fn=<ViewBackward>)

因为我必须将states传递为forward()的参数,所以我还是更希望使用states[0]而不是h

我刚刚注意到grad_fn是不同的,因此我想知道如果使用hstates进行进一步的输出计算,它对反向传播是否有任何影响。

我几乎无法想象有什么区别,所以我可能会继续讲states[0],但我也想了解为什么会有所不同。

谢谢!

1 个答案:

答案 0 :(得分:1)

使用h或(通常称为output)是最佳实践,更直观,因为states旨在传递给lstm供内部使用(请考虑tensorflow的{{3 }},看看为什么会这样。

那说明您是正确的,实际上并没有什么不同。我不确定grad_fn为何不同,但是凭经验,它们的功能相同:

import torch
from torch import nn

lstm = nn.LSTM(10, 5, batch_first=True)
state = (torch.rand(1, 1, 5), torch.rand(1, 1, 5))
inp = torch.rand(1, 1, 10)
h, states = lstm(inp, state)

param = next(lstm.parameters())

l1 = h.sum()
l1.backward(retain_graph=True)
g1 = param.grad.clone()

param.grad.zero_()

l2 = states[0].sum()
l2.backward(retain_graph=True)
g2 = param.grad.clone()

print((g1 == g2).all())  # 1