Efficient way to selectively replace vectors in a tensor in PyTorch

Posted: 2019-11-23 11:17:41

Tags: python pytorch tensor

Given a batch of text sequences converted to a tensor, each word is represented by a word embedding or vector (300-dimensional). I need to selectively replace the vectors of certain specific words with a new set of embeddings. Moreover, this replacement should not happen for every occurrence of a specific word, but only at random. Currently I have the following code that achieves this. It uses two for loops to go over every word, checks whether the word is in the given list splIndices, and then decides, based on the True/False value in selected_, whether that occurrence needs to be replaced.

But can this be done more efficiently?

The code below may not be an MWE; I have tried to simplify it by removing details so that the focus stays on the problem. Please ignore the semantics or purpose of the code, as it may not be properly represented in this snippet. The question is about improving performance.


import torch
import torch.nn as nn

splIndices = [45, 62, 2983, 456, 762]  # vocabulary indices whose vectors need to be replaced
splFreqs = 2000  # assuming the words in splIndices occur 2000 times in total
selected_ = torch.Tensor(2000).uniform_(0, 1) < 0.2  # tensor with ~20% of the entries True
replIndexCtr = 0  # counter for selected_

# Dictionary with the replacement vectors. This is a dummy mapping;
# the original function depends on some property of the word.
diffVector = {45: torch.Tensor(300).uniform_(0, 1), ..., 762: torch.Tensor(300).uniform_(0, 1)}

embedding = nn.Embedding.from_pretrained(embedding_matrix, freeze=False)
tempVals = x  # shape [32, 41] - batch of 32 sequences with 41 words each
x = embedding(x)  # shape [32, 41, 300] - vocab indices replaced by their embeddings

# iterate over sequences in the batch
for i, item in enumerate(x):
    # iterate over words in the sequence
    for j, stuff in enumerate(item):
        if tempVals[i][j].item() in splIndices:
            if selected_[replIndexCtr]:
                x[i, j] = diffVector[tempVals[i][j].item()]
            replIndexCtr += 1


1 Answer:

Answer 0 (score: 1)

It can be vectorized in the following way:


Here is an example with a possible output:

import torch
import torch.nn as nn
import torch.nn.functional as F

batch_size, sentence_size, vocab_size, emb_size = 3, 2, 15, 1

# Use a fixed bias as a marker to tell the two embedders apart
embedder_1 = nn.Linear(vocab_size, emb_size)
embedder_1.weight.data.fill_(0)
embedder_1.bias.data.fill_(200)

embedder_2 = nn.Linear(vocab_size, emb_size)
embedder_2.weight.data.fill_(0)
embedder_2.bias.data.fill_(404)

# Indices of the words that need a different embedding
replace_list = [3, 5, 7, 9] 

# Make a binary mask highlighting the special words' indices
mask = torch.zeros(batch_size, sentence_size, vocab_size)
mask[..., replace_list] = 1

# Make random dataset
data_indices = torch.randint(0, vocab_size, (batch_size, sentence_size))
data_onehot = F.one_hot(data_indices, vocab_size)

# Check whether the one-hot of a word collides with the replace mask
replace_mask = mask.long() * data_onehot
replace_mask = torch.sum(replace_mask, dim=-1).bool()  # boolean mask, needed for indexing below

data_emb = torch.empty(batch_size, sentence_size, emb_size)

# Fill default embeddings
data_emb[~replace_mask] = embedder_1(data_onehot[~replace_mask].float())
if replace_mask.any():  # only if there is at least one special word in the batch
    # Fill special embeddings
    data_emb[replace_mask] = embedder_2(data_onehot[replace_mask].float())

print(data_indices)
print(replace_mask)
print(data_emb.squeeze(-1).int())
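
Applied back to the setup in the question (an nn.Embedding layer plus a diffVector lookup), the same masking idea can also be written without one-hot vectors, by building a boolean mask directly over the index tensor. The following is a minimal sketch with assumed shapes and dummy data; splIndices, diffVector and the 20% selection come from the question, while repl_matrix, is_special, replace_pos and word_slot are illustrative names, not part of any existing API:

import torch
import torch.nn as nn

# Illustrative sketch with dummy data, not the original code from the question.
vocab_size, emb_size = 3000, 300
splIndices = [45, 62, 2983, 456, 762]            # vocab indices to replace (from the question)

# Dummy replacement vectors, stacked so each special word maps to one row.
diffVector = {idx: torch.rand(emb_size) for idx in splIndices}
repl_matrix = torch.stack([diffVector[idx] for idx in splIndices])  # shape [5, 300]

embedding = nn.Embedding(vocab_size, emb_size)   # stands in for the pretrained embedding

x = torch.randint(0, vocab_size, (32, 41))       # batch of token indices, shape [32, 41]
emb = embedding(x)                               # shape [32, 41, 300]

spl = torch.tensor(splIndices)
# Boolean mask of positions that hold any of the special words.
is_special = (x.unsqueeze(-1) == spl).any(dim=-1)          # shape [32, 41]
# Keep roughly 20% of those positions at random, mirroring selected_.
replace_pos = is_special & (torch.rand(x.shape) < 0.2)

# For each selected position, find which special word it is and gather its new vector.
word_slot = (x[replace_pos].unsqueeze(-1) == spl).long().argmax(dim=-1)
emb[replace_pos] = repl_matrix[word_slot]

The per-position random draw here stands in for the pre-computed selected_ tensor and its counter; if reproducing the exact original behaviour matters, the same boolean mask can instead be filled from selected_ in occurrence order.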