我具有以下用于CBOW嵌入的PyTorch模型:
class CBOW(nn.Module):
def __init__(self, vocab_size, context_radius, embedding_dim):
super(CBOW, self).__init__()
self.embedding = nn.Embedding(vocab_size, embedding_dim)
self.linear1 = nn.Linear(2 * context_radius * embedding_dim, 128)
self.linear2 = nn.Linear(128, vocab_size)
def forward(self, input):
output1 = self.embedding(input).view(1, -1)
output2 = F.relu(self.linear1(output1))
output3 = self.linear2(output2)
return F.log_softmax(output3, dim=1)
当我向其发送尺寸为(2*context_radius)
的张量时,它工作得很好,但是当我向(N, 2*context_radius)
发送其中N> 1的张量时,效果很好。例如,如果我的上下文半径为2(因此上下文为4 ),则可以发送输入[36, 30, 12, 7]
,并得到形状为(1, vocab_size)
的输出。但是我希望能够发送一批上下文并检索一批输出。