tensorflow的model.fit sample_weight参数的功能与sklearn的指标sample_weight不同吗?

时间:2020-05-31 07:06:23

标签: python tensorflow

我看到在TF中为model.fit()提供样本权重时,看起来每个样本损失只是乘以样本权重,然后求和,然后除以样本数。就是提供较小量值的样本重量值也会导致损失的实际值也变小,其他所有情况都相同。这是有道理的。

但是对于sklearn的指标(至少为log_loss),情况似乎并非如此。就我而言,提供相同的权重似乎会增加损失值的规模。我希望这两个库的行为类似。

我该如何匹配TF计算的功能?

import tensorflow as tf
from sklearn.metrics import log_loss

y_true = [[0, 1, 0], [0, 0, 1]]
y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
sample_weight = [0.3, 0.7]
cce = tf.keras.losses.CategoricalCrossentropy()

# produces same output
print(cce(y_true, y_pred).numpy())
print(log_loss(y_true, y_pred))

# produces different output
print(cce(y_true, y_pred, sample_weight=tf.constant(sample_weight)).numpy())
print(log_loss(y_true, y_pred, sample_weight=sample_weight))

0 个答案:

没有答案