Keras中示例暹罗网络的准确性和错误率

时间:2020-04-11 19:44:45

标签: machine-learning keras deep-learning siamese-network

我一直在跟踪此示例here,并且想知道此精度函数的工作原理:

def compute_accuracy(y_true, y_pred):
'''Compute classification accuracy with a fixed threshold on distances.
'''
    pred = y_pred.ravel() < 0.5
    return np.mean(pred == y_true)

据我所知,这种情况下网络的输出将是两对之间的距离。那么在这种情况下我们如何计算精度呢? “ 0.5”阈值是什么意思?另外,如何计算错误率?

1 个答案:

答案 0 :(得分:1)

在理解该示例时,似乎有一些空白需要首先填补:

如果您研究数据准备步骤(即create_pairs方法),您将意识到正数对(即,属于同一类别的成对样本)被分配了标签1(即正数/真)负数对(即,属于不同类别的样本对)的标号为0(即负数/假)。

此外,本示例中的暹罗网络设计为,如果给定一对样本作为输入,它将预测其距离作为输出。通过使用对比损失作为模型的损失函数,对模型进行训练,从而预测给定正对作为输入,可以预测出较小的距离值(因为它们属于同一类,因此它们的距离应该很短,即传达相似性),并在输入为负对的情况下,可以预测到较大的距离值(因为它们属于差异类,因此它们的距离应较高,即传达不相似性)。作为练习,请尝试使用代码中的对比损失定义,通过数值考虑这些点(即y_true为1且y_true为0时)来确认这些点。

因此,示例中的精度函数实现为:将固定的任意阈值(即0.5)应用于预测的距离值(即y_pred)(这意味着本文的作者例如,示例确定距离值小于0.5表示正对;您可以决定使用另一个阈值,但是根据实验/经验,这应该是一个合理的选择)。然后,将结果与真实标签值(即y_true

)进行比较
  • y_pred小于0.5(y_pred < 0.5等于True)时:如果y_true为1(即正数),则表示预测该网络的真实值与真实标签一致(即True == 1等于True),因此该样本的预测被计入正确的预测(即准确性)中。但是,如果y_true为0(即负),则此样本的预测不正确(即True == 0等于False),因此这将无助于正确的预测。 / p>

  • y_pred等于或大于0.5(y_pred < 0.5等于False)时:适用与上述相同的理由(请留作练习!)。

  • p>

注意:请不要忘记对一批样本训练模型。因此,y_predy_true并不是一个单一的值;它们是值数组,以及上面提到的所有计算/比较都是按元素进行的。

让我们看一下输入的5个样本对批次上的(虚拟)数值示例,以及如何计算该批次上的模型预测的准确性:

>>> y_pred = np.array([1.5, 0.7, 0.1, 0.3, 3.2])
>>> y_true = np.array([1, 0, 0, 1, 0])

>>> pred = y_pred < 0.5
>>> pred
array([False, False,  True,  True, False])

>>> result = pred == y_true
>>> result
array([False,  True, False,  True,  True])

>>> accuracy = np.mean(result)
>>> accuracy
0.6