Tensorflow到PyTorch-model.predict等效

时间:2020-07-09 19:44:18

标签: python tensorflow pytorch

  1. 我正在尝试检索训练的均方误差。在基于TensorFlow的原始代码中,我将这段代码移至PyTorch(出于研究原因)。

    原始TensorFlow代码:

    print("Calculating threshold")
    x_opt_predictions = model.predict(x_opt)
    print("Calculating MSE on optimization set...")
    mse = np.mean(np.power(x_opt - x_opt_predictions, 2), axis=1)
    print("mean is %.5f" % mse.mean())
    print("min is %.5f" % mse.min())
    print("max is %.5f" % mse.max())
    print("std is %.5f" % mse.std())
    tr = mse.mean() + mse.std()

火炬的训练方法:

def train(net, x_train, x_opt, BATCH_SIZE, EPOCHS, input_dim):
    outputs = 0
    mse = 0
    optimizer = optim.SGD(net.parameters(), lr=0.001)
    loss_function = nn.MSELoss()
    loss = 0
    for epoch in range(EPOCHS):
        for i in tqdm(range(0, len(x_train), BATCH_SIZE)):
            batch_y = x_opt[i:i + BATCH_SIZE]
            
            net.zero_grad()
            
            outputs = net(batch_y)
            

            loss = loss_function(outputs, batch_y)
            loss.backward()
            optimizer.step()

        print(f"Epoch: {epoch}. Loss: {loss}")
        print("opt", x_opt.size(), "output", outputs.__sizeof__())

    # VVVVVVVVVVVVVVVVVVVVVVVVVVVVVV
    return np.mean(np.power(x_opt - outputs, 2), axis=1)
    # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    在行“输出”上方看到的
  1. 不是预测的数字数组,获取等效项以生成阈值

  2. 如果还有其他(改善或缺失)的方法来获取此价值,请提前进行升值。

1 个答案:

答案 0 :(得分:0)

变量输出是一个pytorch张量,将其转换为numpy,您需要更改的只是将代码行return np.mean(np.power(x_opt - outputs, 2), axis=1)更改为return np.mean(np.power(x_opt - outputs.cpu().data.numpy(), 2), axis=1)即可将张量转换为numpy数组。如果您的网络中没有使用cuda,则不需要.cpu()部分。