PyTorch:在交替优化中是否需要keep_graph = True?

时间:2019-12-09 19:34:31

标签: python machine-learning pytorch mixture-model

我正在尝试使用PyTorch以交替方式优化两个模型。第一个是神经网络,它正在更改我的数据的表示形式(即输入数据x上的映射f(x),由权重W进行参数设置)。第二个是在f(x)点(即神经网络空间(而不是输入空间中的聚类点))上运行的高斯混合模型。我正在使用期望最大化来优化GMM,因此可以解析地得出参数更新,而不是使用梯度下降。

我在这里有两个损失函数:第一个是距离|| f(x)-f(y)||的函数,第二个是高斯混合模型的损失函数(即如何“聚类”)一切都在NN表示空间中找到)。我想做的是使用上述两个损失函数进行神经网络优化(因为它取决于这两个函数),然后对GMM做一个期望最大化的步骤。代码看起来像这样(由于有大量代码,所以我删除了很多东西):

data, labels = load_dataset()
net = NeuralNetwork()
net_optim = torch.optim.Adam(net.parameters(), lr=0.05, weight_decay=1)

# initialize weights, means, and covariances for the Gaussian clusters
concentrations, means, covariances, precisions = initialization(net.forward_one(data)) 

for i in range(1000):
    net_optim.zero_grad()
    pairs, pair_labels = pairGenerator(data, labels) # samples some pairs of datapoints
    outputs = net(pairs[:, 0, :], pairs[:, 1, :]) # computes pairwise distances
    net_loss = NeuralNetworkLoss(outputs, pair_labels) # loss function based on pairwise dist.

    embedding = net.forward_one(data) # embeds all data in the NN space

    log_prob, log_likelihoods = expectation_step(embedding, means, precisions, concentrations)
    concentrations, means, covariances, precisions = maximization_step(embedding, log_likelihoods)

    gmm_loss = GMMLoss(log_likelihoods, log_prob, precisions, concentrations)

    net_loss.backward(retain_graph=True)
    gmm_loss.backward(retain_graph=True)
    net_optim.step()

本质上,这是正在发生的事情:

  1. 从数据集中采样一些点对
  2. 通过神经网络推动点对并根据这些输出计算网络损耗
  3. 使用NN嵌入所有数据点,并在该嵌入空间中执行聚类EM步骤
  4. 基于聚类参数的计算变异损失(ELBO)
  5. 同时使用变异损失和网络损失更新神经网络参数

但是,要执行(5),我需要添加标志retain_graph=True,否则会出现错误:

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

似乎有两个损失函数意味着我需要保留计算图?

我不确定如何解决此问题,例如retain_graph=True,在迭代400左右,每次迭代大约需要30分钟才能完成。有谁知道我该如何解决?我预先表示歉意-我对自动区分还是很陌生。

1 个答案:

答案 0 :(得分:0)

我建议这样做

total_loss = net_loss + gmm_loss
total_loss.backward()

请注意,net_loss w.r.t gmm权重的梯度为0,因此将损失相加不会产生任何影响。 这是pytorch上与rest_graph有关的一个好线程。 https://discuss.pytorch.org/t/what-exactly-does-retain-variables-true-in-loss-backward-do/3508/24