我正在尝试在张量流中构建神经网络,其中I型错误(假阳性)的成本比II型错误(假阴性)的成本更高。有没有办法在训练过程中强加这个(即输入成本矩阵)?通过指定class_weight参数,可以使用简单模型(例如scikit学习中的Logistic回归)进行建模。
cw = {0: 3,1:1}
clf = LogisticRegression(class_weight = cw )
在这种情况下,错误地预测0的成本比错误地预测1的成本高3倍。但是,这无法通过神经网络执行,因此我想看看在tensorflow中是否可行。
谢谢
答案 0 :(得分:0)
您可以使用tf.nn.weighted_cross_entropy_with_logits,它是pos_weight
的参数。
此自变量加权正类,如文档所述(至少在TF2.0中):
A value pos_weights > 1 decreases the false negative count, hence increasing the recall.
Conversely setting pos_weights < 1 decreases the false positive count and increases the precision.
根据您的情况,您可以创建自定义损失函数,如下所示:
import tensorflow as tf
# Output logits from your network, not the values after sigmoid activation
class WeightedBinaryCrossEntropy:
def __init__(self, positive_weight: float):
self.positive_weight = positive_weight
def __call__(self, targets, logits):
return tf.nn.weighted_cross_entropy_with_logits(
targets, logits, pos_weight=self.positive_weight
)
并使用它创建一个自定义的神经网络,例如使用tf.keras
(对样本的权重与问题相同:
import numpy as np
model = tf.keras.models.Sequential(
[
tf.keras.layers.Dense(32, input_shape=(10,)),
tf.keras.layers.Activation("relu"),
tf.keras.layers.Dense(10),
tf.keras.layers.Activation("relu"),
# Output one logit for binary classification
tf.keras.layers.Dense(1),
]
)
# Example random data
data = np.random.random((32, 10))
targets = np.random.randint(2, size=32)
# 3 times as costly to make type I error
model.compile(optimizer="rmsprop", loss=WeightedBinaryCrossEntropy(positive_weight=3))
model.fit(data, targets, batch_size=32)
答案 1 :(得分:-1)
您可以使用对数刻度。对于错误地预测为0的0,y - ŷ = -1
,对数为1.71。对于预测为0的1,y - ŷ = 1
对数等于0.63。对于y == ŷ
,对数等于0。代价几乎是原来的三倍,因为0被错误地预测为1。
import numpy as np
from math import exp
loss=abs(1-exp(-np.log(exp(y-ŷ))))
#abs(1-exp(-np.log(exp(0))))
#Out[53]: 0.0
#abs(1-exp(-np.log(exp(-1))))
#Out[54]: 1.718281828459045
#abs(1-exp(-np.log(exp(1))))
#Out[55]: 0.6321205588285577
然后您将获得凸优化。实施:
import keras.backend as K
def custom_loss(y_true,y_pred):
return K.mean(abs(1-exp(-np.log(exp(y_true-y_pred)))))
然后:
model.compile(loss=custom_loss, optimizer=sgd,metrics = ['accuracy'])