pytorch nn.CrossEntropyLoss()中的交叉熵损失

时间:2017-11-01 23:21:16

标签: torch entropy pytorch loss cross-entropy

也许有人能够在这里帮助我。我正在尝试计算网络给定输出的交叉熵损失

*

和所需的标签,格式为

print output
Variable containing:
1.00000e-02 *
-2.2739  2.9964 -7.8353  7.4667  4.6921  0.1391  0.6118  5.2227  6.2540     
-7.3584
[torch.FloatTensor of size 1x10]

其中x是0到9之间的整数。 根据pytorch文档(http://pytorch.org/docs/master/nn.html

print lab
Variable containing:
x
[torch.FloatTensor of size 1]

这应该可行,但不幸的是我得到了一个奇怪的错误

criterion = nn.CrossEntropyLoss()
loss = criterion(output, lab)

任何人都可以帮助我吗?我真的很困惑,并尝试了几乎所有我能想到的有用的东西。

最佳

1 个答案:

答案 0 :(得分:5)

请检查此代码

data['T'] = pd.to_datetime(data['T'])
plt.figure(figsize=(9, 5))
plt.plot(data['T'],data['Close'], lw=1, label='ARK CLOSE')
plt.plot(data['T'],data['MA_5'], 'g', lw=1, label='5-day SMA (green)')
plt.plot(data['T'],data['MA_20'], 'y', lw=2, label='20-day SMA (red)')

plt.setp(plt.gca().get_xticklabels(), rotation=30)
plt.show()

这将很好地打印出损失:

import torch
import torch.nn as nn
from torch.autograd import Variable

output = Variable(torch.rand(1,10))
target = Variable(torch.LongTensor([1]))

criterion = nn.CrossEntropyLoss()
loss = criterion(output, target)
print(loss)