Why doesn't my PyTorch model recognize the tensors I defined?

Asked: 2020-02-13 09:56:49

Tags: python machine-learning neural-network pytorch

I'm new to PyTorch, and as practice I'm trying to implement the model from a paper I've been reading.

Here is the PDF of the paper I'm referencing: https://dl.acm.org/doi/pdf/10.1145/3178876.3186066?download=true

Here is the code I wrote:

import torch
import torch.nn.functional as F
from torch.autograd import Variable
# XGBoost below is a custom wrapper class (definition not shown here)

class Tem(torch.nn.Module):
    def __init__(self, embedding_size, hidden_size):
        super(Tem, self).__init__()
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.leaf_size = 0

        self.xgb_model = None
        self.vec_embedding = None

        self.multi_hot_Q = None
        self.user_embedding = torch.nn.Linear(1, embedding_size)
        self.item_embedding = torch.nn.Linear(1, embedding_size)

    def pretrain(self, ui_attributes, labels):
        print("Start XGBoost Training...")
        self.xgb_model = XGBoost(ui_attributes, labels)
        self.leaf_size = self.xgb_model.leaf_size
        self.vec_embedding = Variable(torch.rand(self.embedding_size, self.leaf_size, requires_grad=True))

        self.h = Variable(torch.rand(self.hidden_size, 1, requires_grad=True))
        self.att_w = Variable(torch.rand(2 * self.embedding_size, self.hidden_size, requires_grad=True))
        self.att_b = Variable(torch.rand(self.leaf_size, self.hidden_size, requires_grad=True))

        self.r_1 = Variable(torch.rand(self.embedding_size, 1, requires_grad=True))
        self.r_2 = Variable(torch.rand(self.embedding_size, 1, requires_grad=True))
        self.bias = Variable(torch.rand(1, 1, requires_grad=True))

    def forward(self, ui_ids, ui_attributes):
        if self.xgb_model is None:
            raise Exception("Please run Tem.pretrain() to pre-train XGBoost model first.")

        n_data = len(ui_ids)

        att_input = torch.FloatTensor(ui_attributes)
        self.multi_hot_Q = torch.FloatTensor(self.xgb_model.multi_hot(att_input)).permute(0,2,1)

        vq = self.vec_embedding * self.multi_hot_Q

        id_input = torch.FloatTensor(ui_ids)

        user_embedded = self.user_embedding(id_input[:,0].reshape(n_data, 1))

        item_embedded = self.item_embedding(id_input[:,1].reshape(n_data, 1))

        ui = (user_embedded * item_embedded).reshape(n_data, self.embedding_size, 1)

        ui_repeat = ui.repeat(1, 1, self.leaf_size)

        cross = torch.cat([ui_repeat, vq], dim=1).permute(0,2,1)

        re_cross = cross.reshape(cross.shape[0] * cross.shape[1], cross.shape[2])

        attention = torch.mm(re_cross, self.att_w)
        attention = F.leaky_relu(attention + self.att_b.repeat(n_data, 1))
        attention = torch.mm(attention, self.h).reshape(n_data, self.leaf_size)
        attention = F.softmax(attention).reshape(n_data, self.leaf_size, 1)
        attention = self.vec_embedding.permute(1,0) * attention.repeat(1,1,20)

        pool = torch.max(attention, 1).values

        y_hat = self.bias.repeat(n_data, 1) + torch.mm(ui.reshape(n_data, self.embedding_size), self.r_1) + torch.mm(pool, self.r_2)
        y_hat = F.softmax(torch.nn.Linear(1, 2)(y_hat))
        return y_hat

My problem is... it seems that PyTorch doesn't know which tensors it should compute gradients for during backpropagation.

print(tem)
Tem(
  (user_embedding): Linear(in_features=1, out_features=20, bias=True)
  (item_embedding): Linear(in_features=1, out_features=20, bias=True)
)
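
As an extra check (a minimal sketch, assuming the model instance is called tem, as in my training code below), listing the registered parameters also only shows the two Linear layers:

for name, p in tem.named_parameters():
    print(name, tuple(p.shape))
# only user_embedding.weight/bias and item_embedding.weight/bias are printed;
# none of the tensors created in pretrain() appear here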

I searched for this problem, and some people said these tensors should be wrapped with torch.autograd.Variable(), but that doesn't solve my problem. Others said autograd now supports tensors directly, so torch.autograd.Variable() is unnecessary. Thanks a lot if you know how to fix it. Here is my training loop and its output:

loss_func = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adagrad(tem.parameters(), lr=0.02)
for t in range(20):
    prediction = tem(ids_train, att_train)

    loss = loss_func(prediction, y_train)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if t % 5 == 0:
        print("loss: ", loss)
loss: tensor(0.8133, grad_fn=<NllLossBackward>)
loss: tensor(0.8133, grad_fn=<NllLossBackward>)
loss: tensor(0.8133, grad_fn=<NllLossBackward>)
loss: tensor(0.8133, grad_fn=<NllLossBackward>)

1 Answer:

Answer (score: 2):

Your problem is not related to Variable. As you said, it is no longer necessary. To compute the gradients of a tensor declared in your model (one that extends nn.Module), you need to include it in the model's parameters using nn.Parameter(). For example, to include self.h, you can do:

self.h = nn.Parameter(torch.zeros(10, 10))

Now, when you call loss.backward(), the gradient for this variable will be collected (of course, loss must depend on self.h).
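
Applied to your code, a minimal sketch of pretrain() with all of the hand-made tensors registered as parameters (keeping your names and shapes, and still assuming your XGBoost wrapper) could look like this:

    def pretrain(self, ui_attributes, labels):
        print("Start XGBoost Training...")
        self.xgb_model = XGBoost(ui_attributes, labels)
        self.leaf_size = self.xgb_model.leaf_size
        # nn.Parameter registers each tensor with the module, so it shows up in
        # self.parameters() and gets updated by the optimizer (requires_grad=True by default)
        self.vec_embedding = torch.nn.Parameter(torch.rand(self.embedding_size, self.leaf_size))
        self.h = torch.nn.Parameter(torch.rand(self.hidden_size, 1))
        self.att_w = torch.nn.Parameter(torch.rand(2 * self.embedding_size, self.hidden_size))
        self.att_b = torch.nn.Parameter(torch.rand(self.leaf_size, self.hidden_size))
        self.r_1 = torch.nn.Parameter(torch.rand(self.embedding_size, 1))
        self.r_2 = torch.nn.Parameter(torch.rand(self.embedding_size, 1))
        self.bias = torch.nn.Parameter(torch.rand(1, 1))

Note that these parameters only exist after pretrain() has run, so the optimizer has to be constructed after that call; otherwise tem.parameters() will not include them.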

I hope this helps.