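The task is to train a neural network that, given an input x, selects a value v from a fixed list of candidates so as to minimize the difference between x and v. For example, my input is a 100x1 tensor of integers, and the network should output a probability distribution over the candidate values.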
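It works correctly for a 1x1 tensor. With a batch, however, it outputs the same distribution for every example in the input. I want it to output an appropriate distribution for each example separately. How can I fix this?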
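Here is a sketch of what "an appropriate distribution for each example" means. It assumes, as `loss_function` in the code below does, that the quantity to minimize is the squared difference (x - v)^2. Under that loss, the best distribution puts all of its probability on the candidate closest to x, so the ideal distribution changes from row to row. The inputs 0.4, 16 and 86 are illustrative choices and do not come from the question.

```python
import torch
import torch.nn.functional as F

# Candidate values from the question
values = torch.tensor([0.0, 1.0, 3.0, 5.0, 100.0])
# Illustrative inputs with shape (N, 1), like X
x = torch.tensor([[0.4], [16.0], [86.0]])

# Squared difference to every candidate: (N, 1) - (5,) broadcasts to (N, 5)
sq_diff = (x - values) ** 2

# The ideal distribution puts all of its mass on the closest candidate
best_idx = torch.argmin(sq_diff, dim=1)                          # -> [0, 3, 4]
ideal = F.one_hot(best_idx, num_classes=values.numel()).float()  # shape (N, 5)
```

For x = 16 the closest candidate is 5 (index 3), while for x = 86 it is 100 (index 4). A correctly trained network should therefore not produce identical rows for these inputs.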
Here is the code:
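import torch
import torch.nn as nn
import torch.optim as optim

N, D_in, D_out, H = 20, 1, 5, 10
# 100 random integer inputs in [0, 500), stored as floats, shape (100, 1)
X = torch.LongTensor(100, 1).random_(0, 500).float()
# Candidate values the network chooses among
values = torch.FloatTensor([0, 1, 3, 5, 100])

model = nn.Sequential(
    nn.Linear(D_in, H),
    nn.Sigmoid(),
    nn.Linear(H, D_out),
    nn.Softmax(dim=1)  # one probability distribution per row (example)
)

def loss_function(x, y, probability):
    # x: (N, 1), y: (5,), probability: (N, 5); (x - y) broadcasts to (N, 5).
    # Expected squared error per example, divided by the number of candidates,
    # then averaged over the batch.
    return torch.mean(torch.sum(probability*((x-y)**2), 1).div(probability.size(1)))

loss_fn = loss_function
learning_rate = 1e-3
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

model.train(True)
for epoch in range(200):
    pred = model(X)
    loss = loss_fn(X, values, pred)
    print('loss: ', loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()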
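This sketch makes the shapes in `loss_function` explicit. `X` has shape (N, 1) and `values` has shape (5,), so `x - y` broadcasts to (N, 5). Each row of `probability` therefore weights the squared error of every candidate for that example. The check uses two hand-picked inputs and a uniform distribution. Both are illustrative assumptions and do not come from the question.

```python
import torch

def loss_function(x, y, probability):
    # Same formula as in the question: per-example expected squared error,
    # divided by the number of candidates, then averaged over the batch
    return torch.mean(torch.sum(probability * ((x - y) ** 2), 1).div(probability.size(1)))

values = torch.tensor([0.0, 1.0, 3.0, 5.0, 100.0])
x = torch.tensor([[0.0], [100.0]])      # shape (2, 1)
probability = torch.full((2, 5), 0.2)   # uniform over the 5 candidates

sq_err = (x - values) ** 2              # broadcasts to shape (2, 5)
loss = loss_function(x, values, probability)
# Row 0: 0.2 * (0 + 1 + 9 + 25 + 10000) / 5 = 401.4
# Row 1: 0.2 * (10000 + 9801 + 9409 + 9025 + 0) / 5 = 1529.4
# Mean over the batch: 965.4
```

`probability.size(1)` is the constant 5, so the `.div(...)` only rescales the loss, and therefore the gradients, by 1/5.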
Here are the inputs and outputs for 20 examples. Input:
tensor([[ 16.],
[ 28.],
[188.],
[ 16.],
[379.],
[366.],
[366.],
[421.],
[ 13.],
[ 86.],
[ 38.],
[135.],
[173.],
[153.],
[223.],
[133.],
[316.],
[108.],
[133.],
[295.]])
Output:
tensor([[1.2926e-04, 1.3165e-04, 1.2761e-04, 1.3385e-04, 9.9948e-01],
[9.9158e-05, 1.0086e-04, 9.8492e-05, 1.0170e-04, 9.9960e-01],
[3.3826e-05, 3.4437e-05, 3.4443e-05, 3.2122e-05, 9.9987e-01],
[1.2926e-04, 1.3165e-04, 1.2761e-04, 1.3385e-04, 9.9948e-01],
[3.3173e-05, 3.3778e-05, 3.3758e-05, 3.1419e-05, 9.9987e-01],
[3.3175e-05, 3.3780e-05, 3.3759e-05, 3.1421e-05, 9.9987e-01],
[3.3175e-05, 3.3780e-05, 3.3759e-05, 3.1421e-05, 9.9987e-01],
[3.3170e-05, 3.3775e-05, 3.3754e-05, 3.1416e-05, 9.9987e-01],
[1.4026e-04, 1.4281e-04, 1.3817e-04, 1.4524e-04, 9.9943e-01],
[4.4578e-05, 4.5341e-05, 4.5195e-05, 4.3355e-05, 9.9982e-01],
[8.1636e-05, 8.2994e-05, 8.1468e-05, 8.2810e-05, 9.9967e-01],
[3.5991e-05, 3.6629e-05, 3.6643e-05, 3.4392e-05, 9.9986e-01],
[3.4158e-05, 3.4773e-05, 3.4784e-05, 3.2472e-05, 9.9986e-01],
[3.4881e-05, 3.5505e-05, 3.5521e-05, 3.3232e-05, 9.9986e-01],
[3.3425e-05, 3.4031e-05, 3.4026e-05, 3.1694e-05, 9.9987e-01],
[3.6153e-05, 3.6794e-05, 3.6807e-05, 3.4561e-05, 9.9986e-01],
[3.3191e-05, 3.3796e-05, 3.3777e-05, 3.1439e-05, 9.9987e-01],
[3.9212e-05, 3.9897e-05, 3.9870e-05, 3.7751e-05, 9.9984e-01],
[3.6153e-05, 3.6794e-05, 3.6807e-05, 3.4561e-05, 9.9986e-01],
[3.3207e-05, 3.3812e-05, 3.3795e-05, 3.1457e-05, 9.9987e-01]],
grad_fn=<SoftmaxBackward>)
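This check makes "the same distribution for every example" concrete. In the printed output, the last column (value 100) has a probability above 0.999 in every row, so the arg max is index 4 everywhere. The sketch below uses only the 20 inputs printed above and computes which candidate actually minimizes the squared difference for each of them. It is a diagnostic, not a fix.

```python
import torch

values = torch.tensor([0.0, 1.0, 3.0, 5.0, 100.0])

# The 20 inputs printed above, reshaped to (20, 1)
x = torch.tensor([16., 28., 188., 16., 379., 366., 366., 421., 13., 86.,
                  38., 135., 173., 153., 223., 133., 316., 108., 133., 295.]).unsqueeze(1)

# Index of the closest candidate for each example
target_idx = torch.argmin((x - values) ** 2, dim=1)

# In the printed output every row puts > 0.999 on the last column (index 4)
pred_idx = torch.full_like(target_idx, 4)

# Rows where the network's choice differs from the closest candidate
mismatch = (target_idx != pred_idx).nonzero().flatten()
# mismatch -> rows 0, 1, 3, 8, 10 (inputs 16, 28, 16, 13, 38), whose closest value is 5
```

Five of the 20 examples should favor the value 5 (index 3) rather than 100, yet the network assigns almost all of their probability to 100.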