我正在尝试将seq2seq模型用于我自己的任务https://github.com/spro/practical-pytorch/blob/master/seq2seq-translation/seq2seq-translation.ipynb
我在解码器阶段有两个张量
rnn_output: (1, 1, 256) # time_step x batch_size x hidden_dimension
encoder_inputs: (10, 1, 256) # seq_len x batch_size x hidden_dimension
它们应该被乘以得到形状的注意力得分(在softmax之前)
attn_score: (10, 1, 1)
最好的方法是什么?笔记本似乎使用for循环,有没有更好的矩阵乘法运算?
答案 0 :(得分:3)
没有使用pytorch
的经验,但是可以这样做吗?
torch.einsum('ijk,abk->abc', (rnn_output, encoder_inputs))
将点积乘在最后一个轴上并添加几个空轴。
使用纯粹的numpy可以实现类似的功能(pytorch 0.4
还没有...
符号)
np.einsum('...ik,...jk', rnn_output.numpy(), encoder_inputs.numpy())
或np.tensordot
np.tensordot(rnn_output.numpy(), encoder_inputs.numpy(), axes=[2,2])
但是在这里你将获得输出形状:(1, 1, 10, 1)
你可以通过挤压和重新扩展来解决这个问题(几乎可以肯定必须有一些更清洁的方法来执行此操作)
np.tensordot(rnn_output.numpy(), encoder_inputs.numpy(), axes=[2,2]).squeeze()[..., None, None]
答案 1 :(得分:2)
使用torch.bmm()
的示例:
import torch
from torch.autograd import Variable
import numpy as np
seq_len = 10
rnn_output = torch.rand((1, 1, 256))
encoder_outputs = torch.rand((seq_len, 1, 256))
# As computed in the tutorial:
attn_score = Variable(torch.zeros(seq_len))
for i in range(seq_len):
attn_score[i] = rnn_output.squeeze().dot(encoder_outputs[i].squeeze())
# note: the code would fail without the "squeeze()". I would assume the tensors in
# the tutorial are actually (,256) and (10, 256)
# Alternative using batched matrix multiplication (bmm) with some data reformatting first:
attn_score_v2 = torch.bmm(rnn_output.expand(seq_len, 1, 256),
encoder_outputs.view(seq_len, 256, 1)).squeeze()
# ... Interestingly though, there are some numerical discrepancies between the 2 methods:
np.testing.assert_array_almost_equal(attn_score.data.numpy(),
attn_score_v2.data.numpy(), decimal=5)
# AssertionError:
# Arrays are not almost equal to 5 decimals
#
# (mismatch 30.0%)
# x: array([60.32436, 69.04288, 72.04784, 70.19503, 71.75543, 67.45459,
# 63.01708, 71.70189, 63.07552, 67.48799], dtype=float32)
# y: array([60.32434, 69.04287, 72.0478 , 70.19504, 71.7554 , 67.4546 ,
# 63.01709, 71.7019 , 63.07553, 67.488 ], dtype=float32)