Failing to train SkipGram word embeddings in PyTorch

Date: 2017-10-13 23:04:26

Tags: python-3.x deep-learning pytorch

I am training skipgram word embeddings using the well-known model described in https://arxiv.org/abs/1310.4546. I want to train it in PyTorch, but I am getting errors and I cannot figure out where they come from. Below I provide my model class, training loop and batching method. Does anyone have any insight into what is going on?

I get the error on the line output = loss(data, target). Since CrossEntropyLoss expects a long tensor, it is strange that a <class 'torch.LongTensor'> causes a problem. The output shape may also be wrong, i.e. torch.Size([1000, 100, 1000]) after the feedforward.
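For reference, here is a minimal sketch (not from the question) of the shapes nn.CrossEntropyLoss expects: raw scores of shape (N, C) and a 1-D LongTensor of class indices of shape (N,).

import torch
import torch.nn as nn
from torch.autograd import Variable

loss_fn = nn.CrossEntropyLoss()

# Input: raw (unnormalised) class scores, shape (N, C) = (batch, classes).
scores = Variable(torch.randn(100, 1000))

# Target: a LongTensor of class indices in [0, C), shape (N,) -- not one-hot.
labels = Variable(torch.LongTensor(100).random_(1000))

print(loss_fn(scores, labels))  # a scalar loss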

I define my model as:


import torch
import torch.nn as nn

torch.manual_seed(1)

class SkipGram(nn.Module):

    def __init__(self, vocab_size, embedding_dim):
        super(SkipGram, self).__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.hidden_layer = nn.Linear(embedding_dim, vocab_size)

        # Loss needs to be input: (minibatch (N), C) target: (minibatch, 1), each label is a class
        # Calculate loss in training            

    def forward(self, x):
        embeds = self.embeddings(x)
        x = self.hidden_layer(embeds)
        return x
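
For what it's worth, the 3-D shape mentioned above is easy to reproduce: nn.Embedding and nn.Linear both act on the last dimension, so a 2-D batch of word indices produces a 3-D output. A quick check (the input shape here is my guess, since the batching code is not shown):

import torch
from torch.autograd import Variable

net = SkipGram(1000, 300)

# A 2-D LongTensor of word indices, e.g. 1000 rows of 100 positions each,
# comes out as (1000, 100, 1000): one score per vocabulary word, per position.
x = Variable(torch.LongTensor(1000, 100).random_(1000))
print(net(x).size())   # torch.Size([1000, 100, 1000])

# A 1-D batch of indices gives the 2-D (N, C) shape CrossEntropyLoss wants.
x1 = Variable(torch.LongTensor(100).random_(1000))
print(net(x1).size())  # torch.Size([100, 1000])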

My training is defined as follows, using my batching helpers batch_index_gen and build_tensor_from_batch_index (a guessed sketch of these follows the loop):

import torch.optim as optim
from torch.autograd import Variable

net = SkipGram(1000, 300)
optimizer = optim.SGD(net.parameters(), lr=0.01)

batch_size = 100
size = len(train_ints)

batches = batch_index_gen(batch_size, size)
inputs, targets = build_tensor_from_batch_index(batches[0], train_ints)


for i in range(100):
    running_loss = 0.0

    for batch_idx, batch in enumerate(batches):
        data, target = build_tensor_from_batch_index(batch, train_ints)
#         if (torch.cuda.is_available()):
#             data, target = data.cuda(), target.cuda()
#             net = net.cuda()
        data, target = Variable(data), Variable(target)
        optimizer.zero_grad()
        output = net.forward(data)
        loss = nn.CrossEntropyLoss()    
        output = loss(data, target)
        output.backward()
        optimizer.step()                        
        running_loss += loss.data[0]
        optimizer.step()

        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                i, batch_idx * len(batch_size), len(size),
                100. * (batch_idx * len(batch_size)) / len(size), loss.data[0]))
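
The definitions of batch_index_gen and build_tensor_from_batch_index are not shown in the question. Purely as a hypothetical illustration, and assuming train_ints is a list of (center_word, context_word) index pairs, they might look something like this; the asker's actual helpers may well differ:

import torch

def batch_index_gen(batch_size, size):
    # Hypothetical: split range(size) into (start, end) slices of batch_size.
    return [(i, min(i + batch_size, size)) for i in range(0, size, batch_size)]

def build_tensor_from_batch_index(batch, train_ints):
    # Hypothetical: assumes train_ints is a list of (center, context) pairs.
    start, end = batch
    centers = torch.LongTensor([c for c, _ in train_ints[start:end]])
    contexts = torch.LongTensor([t for _, t in train_ints[start:end]])
    return centers, contexts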

1 Answer:

Answer 0 (score: 0)

I am not sure about this, since I have not read the paper, but it seems strange that you compute the loss with the raw data and the target:

output = loss(data, target)

Considering that the output of your network is output = net.forward(data), I think you should compute your loss as:

error = loss(output, target)

If this does not help, simply point me to the part of the paper that concerns the loss function.
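
To make that concrete, here is a sketch of the question's inner loop with this fix applied. It also builds the criterion once outside the loop, calls optimizer.step() exactly once per batch (the question calls it twice), and reads the loss value from the result rather than the criterion (loss.data[0] in the question indexes the module, not a value). Note too that len(batch_size) and len(size) in the print statement will fail, since both are plain ints.

import torch.nn as nn
from torch.autograd import Variable

criterion = nn.CrossEntropyLoss()  # build the criterion once, outside the loop

for i in range(100):
    running_loss = 0.0
    for batch_idx, batch in enumerate(batches):
        data, target = build_tensor_from_batch_index(batch, train_ints)
        data, target = Variable(data), Variable(target)

        optimizer.zero_grad()
        output = net(data)                 # network scores, not the raw data
        # CrossEntropyLoss wants 2-D (N, C) scores and a 1-D LongTensor target;
        # a 3-D output such as (1000, 100, 1000) would first need
        # output = output.view(-1, 1000) and target = target.view(-1).
        error = criterion(output, target)  # i.e. loss(output, target)
        error.backward()
        optimizer.step()                   # step exactly once per batch

        running_loss += error.data[0]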