> 0.3在model.predict语句中有什么作用?

时间:2019-07-11 05:00:27

标签: python input keras deep-learning predict

我在将RNN模型的2层输入拟合后,遇到了以下model.predict代码格式。这是在喀拉拉邦执行的。

y_pred = (model.predict(x=[X_test_pad, X_test_psl])>0.3).astype(np.int32)

我对深度学习模型非常陌生,不熟悉我们指定>0.3的格式或输入的任何值。因此,在这种情况下,我想知道>0.3的确切作用。

在这方面的任何澄清将不胜感激。

2 个答案:

答案 0 :(得分:2)

> 0.3不属于model.predict函数。 将预测结果与0.3的决策阈值进行比较。

示例: 脑图像中的肿瘤检测

如果您的模型说40%的人确定我患有肿瘤,我不会因为0.4 <0.5->没有肿瘤而幸福地回家。

相反,我将使用决策阈值与模型结果进行比较。在这种情况下,我们希望该模型与30%的机会相关联的所有输出都被认为是肯定的。

您可以这样写:

model_pred = model.predict(x=[X_test_pad, X_test_psl]
y_pred = (model_pred>0.3).astype(np.int32)

有关更多信息,我建议阅读https://stats.stackexchange.com/questions/312119/classification-probability-threshold

答案 1 :(得分:1)

predict方法返回类似数组的值。基本上,您正在将预测中的每一行(浮点数,可能是概率)与0.3的阈值进行比较。您可以想象它像列表理解一样工作。它将返回另一个类似于布尔值的数组结构,指示每个预测是否超过0.3。对于最后一步,您基本上是通过转换true / false值将此布尔数组转换为整数数组。从数学上讲,您可以认为它像步进函数一样工作。

我想如果将代码解构成较小的单元,将会更容易理解代码的工作方式。首先,检查预测的输出,然后检查比较的输出,依此类推。