I have a batch of text sequences converted into a tensor, with every word represented by a word embedding (a 300-dimensional vector). I need to selectively replace the vectors of certain specific words with a new set of embeddings. Moreover, the replacement targets only occurrences of those specific words and does not happen at random. Currently I have the code below to do this. It iterates over every word with two for loops, checks whether the word is in the list splIndices, and then, based on the True/False value in selected_, decides whether that word's vector should be replaced.
But can this be done more efficiently?
The code below may not be an MWE; I have tried to simplify it by stripping out details so that the focus stays on the problem. Please ignore the semantics or purpose of the code, as it may not be represented properly in this snippet. The question is about improving performance.
import torch
import torch.nn as nn

splIndices = [45, 62, 2983, 456, 762]  # vocabulary indices whose vectors need to be replaced
splFreqs = 2000                        # assuming the words in splIndices occur 2000 times in total
selected_ = torch.Tensor(2000).uniform_(0, 1) > 0.2  # boolean mask; with > 0.2 roughly 80% of the entries are True
replIndexCtr = 0                       # counter for selected_

# Dictionary with the replacement vectors. This is a dummy version;
# the original function depends on some property of the word.
diffVector = {45: torch.Tensor(300).uniform_(0, 1), ..., 762: torch.Tensor(300).uniform_(0, 1)}

embeding = nn.Embedding.from_pretrained(embedding_matrix, freeze=False)
tempVals = x     # shape [32, 41] - batch of 32 sequences with 41 words each
x = embeding(x)  # shape [32, 41, 300] - vocabulary indices replaced by their embeddings

# iterate over the sequences in the batch
for i, item in enumerate(x):
    # iterate over the words of each sequence
    for j, stuff in enumerate(item):
        if tempVals[i][j].item() in splIndices:
            if selected_[replIndexCtr]:
                x[i, j] = diffVector[tempVals[i][j].item()]
            replIndexCtr += 1
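For reference, here is a minimal vectorized sketch of the same replacement written against the setup above. It is only an illustration, not part of the original question: it assumes torch.isin (available in PyTorch 1.10 and later) and that selected_ holds one entry per special-word occurrence in the same row-major order the loops visit them.

# Hypothetical vectorized sketch, not from the original question.
spl = torch.tensor(splIndices)
is_special = torch.isin(tempVals, spl)   # [32, 41] True where the word id is in splIndices
n = int(is_special.sum())                # number of special-word occurrences in this batch
keep = torch.zeros_like(is_special)      # positions whose vectors actually get replaced
keep[is_special] = selected_[:n]         # consume the precomputed selection mask in order
for idx in splIndices:                   # loop over a handful of words, not over every token
    x[keep & (tempVals == idx)] = diffVector[idx]

The only remaining Python loop runs over the few special words; all per-token work happens in tensor operations.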
Answer 0 (score: 1)
It can be vectorized in the following way. The print statements at the end show an example of the possible output:
import torch
import torch.nn as nn
import torch.nn.functional as F
batch_size, sentence_size, vocab_size, emb_size = 3, 2, 15, 1
# Zero weights plus a distinctive constant bias make it easy to see which embedder produced each output
embedder_1 = nn.Linear(vocab_size, emb_size)
embedder_1.weight.data.fill_(0)
embedder_1.bias.data.fill_(200)
embedder_2 = nn.Linear(vocab_size, emb_size)
embedder_2.weight.data.fill_(0)
embedder_2.bias.data.fill_(404)
# Vocabulary indices of the words that need a different embedding
replace_list = [3, 5, 7, 9]
# Build a binary mask highlighting the special words' vocabulary indices
mask = torch.zeros(batch_size, sentence_size, vocab_size)
mask[..., replace_list] = 1
# Make random dataset
data_indices = torch.randint(0, vocab_size, (batch_size, sentence_size))
data_onehot = F.one_hot(data_indices, vocab_size)
# Check whether a word's one-hot vector collides with the replace mask
replace_mask = mask.long() * data_onehot
replace_mask = torch.sum(replace_mask, dim=-1).bool()  # bool() is critical here: an integer tensor would be interpreted as indices, not as a mask
data_emb = torch.empty(batch_size, sentence_size, emb_size)
# Fill default embeddings
data_emb[~replace_mask] = embedder_1(data_onehot[~replace_mask].float())
if replace_mask.any():  # skip if no special word occurs in the batch
    # Fill special embeddings
    data_emb[replace_mask] = embedder_2(data_onehot[replace_mask].float())
print(data_indices)
print(replace_mask)
print(data_emb.squeeze(-1).int())
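As a follow-up, the same masking idea also works with ordinary nn.Embedding tables instead of the constant-bias nn.Linear markers used above for illustration. A small sketch, where the names emb_default and emb_special are hypothetical and not from the original answer:

# Hypothetical follow-up sketch: pick between two real embedding tables per token.
emb_default = nn.Embedding(vocab_size, emb_size)
emb_special = nn.Embedding(vocab_size, emb_size)

default_out = emb_default(data_indices)   # [batch, sentence, emb]
special_out = emb_special(data_indices)   # [batch, sentence, emb]
# replace_mask is the [batch, sentence] bool mask computed above
data_emb = torch.where(replace_mask.unsqueeze(-1), special_out, default_out)

torch.where avoids in-place masked assignment at the cost of computing both lookups for every token, which is usually negligible for embedding layers.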