我可以设置神经网络的预测值吗?

时间:2019-06-24 13:05:52

标签: python neural-network pytorch

我的问题是理论性的,而不是实用的,但是我也可以显示一些代码。我有从域uvw中的随机值映射到xyz域的网络。我希望uvw的某个值能移到我已经知道的xyz中的其他某些值,因为那是我想要获得的功能的想法,并且我希望网络过度拟合。

我的问题分为两个问题:

  1. 我可以将所需的预测值设置到网络中,这样就不必计算那些预测值了吗?
  2. 这是否会影响其他值的预测?

这是我的代码,我想显示它,因此我们可以讨论一些符号。

 # The model is a simple fully connected network mapping a 3D parameter point to 3D
    phi = common.MLP(in_dim=3, out_dim=3).to(args.device)

    # Eps is 1/lambda and max_iters is the maximum number of Sinkhorn iterations to do
    emd_loss_fun = SinkhornLoss(eps=args.sinkhorn_eps, max_iters=args.max_sinkhorn_iters,
                                stop_thresh=1e-3, return_transport_matrix=True) # TODO add r-1 function to the weights  

    mse_loss_fun = torch.nn.MSELoss() 


    # Adam optimizer at first
    optimizer = torch.optim.Rprop(phi.parameters(), lr=args.learning_rate)

    fit_start_time = time.time()

    for epoch in range(args.num_epochs):
        optimizer.zero_grad()

        # Do the forward pass of the neural net, evaluating the function at the parametric points
        y = phi(t)

        # Compute the Sinkhorn divergence between the reconstruction*(using the francis library) and the target
        # NOTE: The Sinkhorn function expects a batch of b point sets (i.e. tensors of shape [b, n, 3])
        # since we only have 1, we unsqueeze so x and y have dimension [1, n, 3]
        with torch.no_grad():
            _, M = emd_loss_fun(phi(t[num:]).unsqueeze(0), x[num:].unsqueeze(0))
            _, Q = emd_loss_fun(phi(t[0:num]).unsqueeze(0), x[0:num].unsqueeze(0)) 
            P[0,num:,num:] = M[0]
            P[0,0:num,0:num] = Q[0]
            #print(y[Q.squeeze().max(0)[1], :])





        # Project the transport matrix onto the space of permutation matrices and compute the L-2 loss
        # between the permuted points
        loss =  mse_loss_fun(y[P.squeeze().max(0)[1], :], x)
        # loss = mse_loss_fun(P.squeeze() @ y,  x)  # Use the transport matrix directly

        # Take an optimizer step
        loss.backward()
        optimizer.step()

        print("Epoch %d, loss = %f" % (epoch, loss.item()))

0 个答案:

没有答案