Model weights are not updated

Time: 2019-05-29 12:43:47

Tags: python machine-learning pytorch

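I just created a simple model. When I run the code, it should loop over the epochs and update the function's weights on each one. The script starts, but it never gets past printing the PyTorch version.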

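Could you check my code? I can't find anything wrong with my loop, yet the weights are never updated. Am I missing something important that the network needs in order to keep running?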

    import time

    import numpy as np
    import point_cloud_utils as pcu
    import torch
    # Assumption: SinkhornLoss comes from the "francis library" (fml) mentioned below
    from fml.nn import SinkhornLoss

    import common  # local module providing loadpointcloud() and MLP
    import test    # local module providing unit_cube()

    # x is a tensor of shape [n, 3] containing the positions of the vertices
    x = torch.from_numpy(common.loadpointcloud().astype(np.float32)).cuda()
    # t is a tensor of shape [n, 3] containing a set of nicely distributed samples in the unit cube
    v, f = test.unit_cube()
    # Sample a point cloud from the mesh (a cube for now?)
    t = torch.from_numpy(pcu.sample_mesh_lloyd(v, f, x.shape[0]).astype(np.float32)).cuda()

    # The model is a simple fully connected network mapping a 3D parameter point to 3D
    phi = common.MLP(in_dim=3, out_dim=3)
    phi.cuda()  # x and t are moved to the GPU above so they match the model's device

    # Eps is 1/lambda and max_iters is the maximum number of Sinkhorn iterations to do
    emd_loss_fun = SinkhornLoss(eps=1e-3, max_iters=x.shape[0],
                                stop_thresh=1e-3, return_transport_matrix=True)

    mse_loss_fun = torch.nn.MSELoss()

    # Adam optimizer at first
    optimizer = torch.optim.Adam(phi.parameters(), lr=10e-3)  # 10e-3 == 0.01
    fit_start_time = time.time()

    for epoch in range(100):
        optimizer.zero_grad()
        # Do the forward pass of the neural net, evaluating the function at the parametric points
        y = phi(t)

        # Compute the Sinkhorn divergence between the reconstruction (using the francis library) and the target
        # NOTE: The Sinkhorn function expects a batch of b point sets (i.e. tensors of shape [b, n, 3])
        # since we only have 1, we unsqueeze so x and y have dimension [1, n, 3]
        with torch.no_grad():
            # Reuse y here; no gradient is tracked inside this block
            _, P = emd_loss_fun(y.unsqueeze(0), x.unsqueeze(0))

        # Project the transport matrix onto the space of permutation matrices and compute the L-2 loss
        # between the permuted points
        loss = mse_loss_fun(y[P.squeeze().max(0)[1], :], x)
        # loss = mse_loss_fun(P.squeeze() @ y, x)  # Use the transport matrix directly

        # Take an optimizer step
        loss.backward()
        optimizer.step()
        print("Epoch %d, loss = %f" % (epoch, loss.item()))

    fit_end_time = time.time()
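
The projection step in the loop, `y[P.squeeze().max(0)[1], :]`, picks for every column of the transport matrix the row with the largest entry. A minimal pure-Python sketch of that column-wise argmax follows. The helper names `column_argmax` and `is_permutation` and the two small matrices are made up for illustration. It also shows that the result is only a true permutation when no two columns pick the same row.

```python
def column_argmax(P):
    """For each column j of P, return the row index i holding the largest entry."""
    n_rows, n_cols = len(P), len(P[0])
    return [max(range(n_rows), key=lambda i: P[i][j]) for j in range(n_cols)]


def is_permutation(indices, n):
    """True when every index in 0..n-1 appears exactly once."""
    return sorted(indices) == list(range(n))


# A sharp transport plan: each column has one dominant row, and all rows differ.
sharp = [
    [0.05, 0.90, 0.05],
    [0.90, 0.05, 0.05],
    [0.05, 0.05, 0.90],
]

# A blurry plan: columns 0 and 1 both pick row 0, so row 1 is never selected.
blurry = [
    [0.50, 0.45, 0.05],
    [0.40, 0.35, 0.25],
    [0.10, 0.20, 0.70],
]

print(column_argmax(sharp), is_permutation(column_argmax(sharp), 3))
print(column_argmax(blurry), is_permutation(column_argmax(blurry), 3))
```

With a blurry plan, some points of `y` would be matched to several targets and others to none, which may be worth checking once the loop actually runs.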

It should run and produce a function that maps t to x. However, the program prints only the following:

python3 main2.py
1.1.0

After that I can't do anything, not even stop the program with Ctrl+C.
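
A hang that ignores Ctrl+C usually means the interpreter is stuck inside native code, because Python only raises `KeyboardInterrupt` between bytecode instructions. Since the output stops before the first "Epoch" line, one way to locate the stalled call is to time each setup stage and have the standard-library `faulthandler` module dump the traceback after a timeout. The sketch below uses only the standard library. `timed_stage`, `arm_watchdog` and `disarm_watchdog` are hypothetical helper names, not part of PyTorch or of the script above.

```python
import faulthandler
import sys
import time


def timed_stage(name, func, *args, **kwargs):
    """Run one stage, printing when it starts and how long it took."""
    print("starting %s ..." % name, flush=True)
    start = time.time()
    result = func(*args, **kwargs)
    print("finished %s in %.2f s" % (name, time.time() - start), flush=True)
    return result


def arm_watchdog(timeout_s=60, file=sys.stderr):
    """Dump all thread tracebacks to `file` every `timeout_s` seconds.

    The dump is written by a watchdog thread, so it still appears when
    the main thread is blocked inside a C or CUDA call.
    """
    faulthandler.dump_traceback_later(timeout_s, repeat=True, file=file)


def disarm_watchdog():
    """Cancel the pending traceback dump once the slow section is over."""
    faulthandler.cancel_dump_traceback_later()
```

For example, calling `arm_watchdog(60)` at the top of the script and wrapping the sampling step as `timed_stage("lloyd sampling", pcu.sample_mesh_lloyd, v, f, x.shape[0])` would show whether that call ever returns. The same wrapper can be applied to `phi.cuda()` and to the first `emd_loss_fun` call.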

0 answers:

No answers yet