我正在阅读log_loss和交叉熵,看起来有两种计算方法
首先
import numpy as np
from sklearn.metrics import log_loss
def cross_entropy(predictions, targets):
N = predictions.shape[0]
ce = -np.sum(targets*np.log(predictions))/N
return ce
predictions = np.array([[0.25,0.25,0.25,0.25],
[0.01,0.01,0.01,0.97]])
targets = np.array([[1,0,0,0],
[0,0,0,1]])
x = cross_entropy(predictions, targets)
print(log_loss(targets, predictions), 'our_answer:', ans)
输出:0.7083767843022996 our_answer: 0.71355817782
,几乎相同。所以这不是问题。
来源:http://wiki.fast.ai/index.php/Log_Loss
以上实施是等式的中间部分。
第二种:计算方法是等式的RHS部分:
res = 0
for act_row, pred_row in zip(targets, np.array(predictions)):
for class_act, class_pred in zip(act_row, pred_row):
res += - class_act * np.log(class_pred) - (1-class_act) * np.log(1-class_pred)
print(res/len(targets))
输出:1.1549753967602232
哪个不太一样。我尝试过与numpy相同的实现也没有用。我做错了什么?
PS:我也很好奇-y log (y_hat)
在我看来它与- sigma(p_i * log( q_i))
相同,那么为什么会有-(1-y) log(1-y_hat)
部分。很明显,我误解了如何计算-y log (y_hat)
。
信用:代码借鉴自:Cross entropy function (python)
答案 0 :(得分:3)
无法重现您在第一部分中报告的结果的差异(您还可以参考ans
变量,您似乎没有定义它,我猜它是x
):
import numpy as np
from sklearn.metrics import log_loss
def cross_entropy(predictions, targets):
N = predictions.shape[0]
ce = -np.sum(targets*np.log(predictions))/N
return ce
predictions = np.array([[0.25,0.25,0.25,0.25],
[0.01,0.01,0.01,0.97]])
targets = np.array([[1,0,0,0],
[0,0,0,1]])
结果:
cross_entropy(predictions, targets)
# 0.7083767843022996
log_loss(targets, predictions)
# 0.7083767843022996
log_loss(targets, predictions) == cross_entropy(predictions, targets)
# True
您的cross_entropy
功能似乎正常。
关于第二部分:
显然,我误解了如何计算
-y log (y_hat)
。
事实上,仔细阅读你所链接的fast.ai wiki,你会发现等式的RHS只适用于二元分类(y
和1-y
之一总是如此是零),这不是这种情况 - 你有一个4级多项分类。所以,正确的表述是
res = 0
for act_row, pred_row in zip(targets, np.array(predictions)):
for class_act, class_pred in zip(act_row, pred_row):
res += - class_act * np.log(class_pred)
即。丢弃(1-class_act) * np.log(1-class_pred)
的减法。
结果:
res/len(targets)
# 0.7083767843022996
res/len(targets) == log_loss(targets, predictions)
# True
在更一般的层面(对数丢失的机制和二进制分类的准确性),您可能会发现this answer有用。