Attention computation at the matrix level

Date: 2021-01-30 09:05:50

Tags: machine-learning nlp pytorch attention-model

I am confused about how the computation works for self-attention (or attention in general).

Let's start with self-attention. I have:

x -> [batch_size, query_len, embedding_size]

In self-attention the shape of the query = shape of the key = shape of the value, so

query = key = value = [batch_size, query_len, embedding_size]

self-attention -> softmax((query x key^T) / sqrt(embedding_size)) x value
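Written out with explicit sizes, the score computation I have in mind looks like this (batch_size = 4, query_len = 10, embedding_size = 768 are just example numbers I picked, and torch.matmul with a transpose is how I would write query x key^T at the matrix level):

import torch

batch_size, query_len, embedding_size = 4, 10, 768
x = torch.randn(batch_size, query_len, embedding_size)

# self-attention: query, key and value are all the same tensor
query = key = value = x

# matrix-level query x key^T: [b, q, d] @ [b, d, k] -> [b, q, k],
# i.e. one score per (query position, key position) pair
scores = torch.matmul(query, key.transpose(1, 2)) / embedding_size ** 0.5
print(scores.shape)  # torch.Size([4, 10, 10])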

When implementing this, I am confused about which dimension the softmax should be applied over, and about how the multiplication of the scores with the values works. I mean: which dimensions get multiplied with each other?

def attention(self, query, key, value):
    # query, key, value each have shape [batch_size, query_len, embedding_size],
    # here [batch_size, 10, 768]
    N = query.shape[0]  # batch size
    query_len, key_len, value_len = query.shape[1], key.shape[1], value.shape[1]

    # dot product of every query vector with every key vector:
    # [b, 10, 768] x [b, 10, 768] -> [b, 10, 10]
    energy = torch.einsum("nqd,nkd->nqk", [query, key])

    # softmax over the key dimension (dim=2), scores scaled by sqrt(embed_dim)
    attention = torch.softmax(energy / (self.embed_dim ** 0.5), dim=2)

    # weighted sum of the values: [b, 10, 10] x [b, 10, 768] -> [b, 10, 768]
    out = torch.einsum("nql,nld->nqd", [attention, value])

    return out, attention
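For comparison, here is how I currently read the whole computation, written with torch.matmul instead of einsum (a self-contained sketch with the same example sizes as above; I am not sure this is equivalent to my einsum version, which is exactly what I would like to confirm):

import torch

batch_size, seq_len, embed_dim = 4, 10, 768
x = torch.randn(batch_size, seq_len, embed_dim)
query = key = value = x  # self-attention

# scores: [b, q, d] @ [b, d, k] -> [b, q, k]
energy = torch.matmul(query, key.transpose(1, 2))

# softmax over the last (key) dimension, i.e. dim=2, so the weights for a
# given query position sum to 1 across all key positions
attention = torch.softmax(energy / embed_dim ** 0.5, dim=-1)

# weighted sum of the values: [b, q, k] @ [b, k, d] -> [b, q, d]
out = torch.matmul(attention, value)

print(out.shape)        # torch.Size([4, 10, 768])
print(attention.shape)  # torch.Size([4, 10, 10])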

0 Answers:

No answers yet.