Classification with a PyTorch neural network

Date: 2019-03-01 13:23:37

Tags: python machine-learning neural-network deep-learning pytorch

The task is to train a neural network to choose a value v from a given list for an input x, so that (x - v) is minimized. For example, I have a 100x1 tensor of integers as input, and the network should output a probability distribution over the candidate values.

It works fine for a 1x1 tensor, but with a batch of inputs it outputs the same distribution for every example. I want it to output an appropriate, distinct distribution for each example. How can I fix this?

Here is the code:

import torch
import torch.nn as nn
import torch.optim as optim

N, D_in, D_out, H = 20, 1, 5, 10
X = torch.LongTensor(100, 1).random_(0, 500).float()
values = torch.FloatTensor([0, 1, 3, 5, 100])

model = nn.Sequential(
    nn.Linear(D_in, H),
    nn.Sigmoid(),
    nn.Linear(H, D_out),
    nn.Softmax(dim=1)
)

def loss_function(x, y, probability):
    return torch.mean(torch.sum(probability*((x-y)**2), 1).div(probability.size(1)))

loss_fn = loss_function
learning_rate = 1e-3
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

model.train(True)
for epoch in range(200):
    pred = model(X)
    loss = loss_fn(X, values, pred)
    print('loss: ', loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
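As a sanity check on what this loss actually computes, here is a tiny hand-worked case reusing the same `values` tensor (the example numbers are illustrative only): for each input x, the loss is the probability-weighted squared error against every candidate value, averaged over the candidates and then over the batch.

```python
import torch

values = torch.tensor([0., 1., 3., 5., 100.])
x = torch.tensor([[4.]])                  # a single example
p = torch.tensor([[0., 0., 0., 1., 0.]])  # all probability mass on the value 5

# Same expression as loss_function above: weighted squared errors,
# summed over candidates, divided by the number of candidates.
loss = torch.mean(torch.sum(p * ((x - values) ** 2), 1).div(p.size(1)))
print(loss.item())  # (4 - 5)^2 = 1, averaged over 5 candidates -> ~0.2
```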

Here are the input and output for 20 examples. Input:

tensor([[ 16.],
         [ 28.],
         [188.],
         [ 16.],
         [379.],
         [366.],
         [366.],
         [421.],
         [ 13.],
         [ 86.],
         [ 38.],
         [135.],
         [173.],
         [153.],
         [223.],
         [133.],
         [316.],
         [108.],
         [133.],
         [295.]])

Output:

tensor([[1.2926e-04, 1.3165e-04, 1.2761e-04, 1.3385e-04, 9.9948e-01],
         [9.9158e-05, 1.0086e-04, 9.8492e-05, 1.0170e-04, 9.9960e-01],
         [3.3826e-05, 3.4437e-05, 3.4443e-05, 3.2122e-05, 9.9987e-01],
         [1.2926e-04, 1.3165e-04, 1.2761e-04, 1.3385e-04, 9.9948e-01],
         [3.3173e-05, 3.3778e-05, 3.3758e-05, 3.1419e-05, 9.9987e-01],
         [3.3175e-05, 3.3780e-05, 3.3759e-05, 3.1421e-05, 9.9987e-01],
         [3.3175e-05, 3.3780e-05, 3.3759e-05, 3.1421e-05, 9.9987e-01],
         [3.3170e-05, 3.3775e-05, 3.3754e-05, 3.1416e-05, 9.9987e-01],
         [1.4026e-04, 1.4281e-04, 1.3817e-04, 1.4524e-04, 9.9943e-01],
         [4.4578e-05, 4.5341e-05, 4.5195e-05, 4.3355e-05, 9.9982e-01],
         [8.1636e-05, 8.2994e-05, 8.1468e-05, 8.2810e-05, 9.9967e-01],
         [3.5991e-05, 3.6629e-05, 3.6643e-05, 3.4392e-05, 9.9986e-01],
         [3.4158e-05, 3.4773e-05, 3.4784e-05, 3.2472e-05, 9.9986e-01],
         [3.4881e-05, 3.5505e-05, 3.5521e-05, 3.3232e-05, 9.9986e-01],
         [3.3425e-05, 3.4031e-05, 3.4026e-05, 3.1694e-05, 9.9987e-01],
         [3.6153e-05, 3.6794e-05, 3.6807e-05, 3.4561e-05, 9.9986e-01],
         [3.3191e-05, 3.3796e-05, 3.3777e-05, 3.1439e-05, 9.9987e-01],
         [3.9212e-05, 3.9897e-05, 3.9870e-05, 3.7751e-05, 9.9984e-01],
         [3.6153e-05, 3.6794e-05, 3.6807e-05, 3.4561e-05, 9.9986e-01],
         [3.3207e-05, 3.3812e-05, 3.3795e-05, 3.1457e-05, 9.9987e-01]],
        grad_fn=<SoftmaxBackward>)
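One plausible cause of the near-identical rows above (an assumption on my part, not something established in the question): with raw inputs in the hundreds, the first Linear layer's pre-activations are huge, so the Sigmoid saturates to 0 or 1 and every example collapses to almost the same hidden vector. A minimal, deterministic sketch with hand-picked (hypothetical) weights:

```python
import torch
import torch.nn as nn

# Hand-picked weights so the effect is deterministic rather than
# dependent on random initialization.
lin = nn.Linear(1, 4)
with torch.no_grad():
    lin.weight.copy_(torch.tensor([[0.5], [-0.3], [1.2], [-0.8]]))
    lin.bias.zero_()

big = torch.sigmoid(lin(torch.tensor([[188.], [421.]])))      # raw scale
small = torch.sigmoid(lin(torch.tensor([[0.188], [0.421]])))  # rescaled

# Raw inputs saturate the sigmoid: both rows come out ~[1, 0, 1, 0].
print(big)
# Rescaled inputs stay in the sigmoid's sensitive region and differ.
print(small)
```

If this is indeed the cause, normalizing X before feeding it to the model (e.g. dividing by its maximum) would be one thing to try.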

0 Answers:

No answers