我对自我关注或仅关注的计算感到困惑。
让我们先谈谈自我注意,我有:
x -> [batch_size, query_len, embedding_size]
现在在self-attention中查询的形状=键的形状=值的形状,所以
query = key = value = [batch_size, query_len, embedding_size]
self-attention -> softmax((query x key)/ square_root(key_len)) x value
在编码中,我很困惑在哪个维度应用 softmax 以及如何处理分数与值的乘法?我的意思是哪个维度彼此相乘?
def attention(self, query, key, value):
N = query.shape[0] # batch size
# context's shape -> [batch_size, 10, 768]
query_len, key_len, value_len = query.shape[1], key.shape[1], value.shape[1]
energy = torch.einsum("nqd,nkd->nqk", [query, key])
# [b, 10, 768]* [b, 10, 768] -> [b, 10, 10]
attention = torch.softmax(energy / (self.embed_dim ** (1/2)), dim= 2)
out = torch.einsum("nql, nld-> nqd", [attention, value])
return out, attention