我想计算自己的损失。我想通过深度学习来预测N点。因此,网络的输出为N点(N * 3)。 numpy计算应为:
import numpy as np
point1 = np.random.random(size=[10, 30, 3])
point2 = np.random.random(size=[10, 30, 3])
losses = []
for s in range(10):
loss = 0
for p in range(30):
p1 = point1[s, p, :]
dis = p1 - point2[s, :, :]
dis = np.linalg.norm(dis, axis=1)
loss += dis.min()
losses.append(loss)
print(loss)
在pytorch中,重点应该是:
point1 = np.random.random(size=[10, 30, 3])
point2 = np.random.random(size=[10, 30, 3])
point1 = torch.from_numpy(point1)
point2 = torch.from_numpy(point2)
如何计算火炬的损失?
任何建议都值得赞赏!