我正在使用数据集训练线性回归模型,该数据集在区间[0,10]中具有实值标签。我在测试集上的预测值有一些超过10的预测。有没有办法将预测限制为10.
我正在考虑进行条件检查,如果预测超过10,我明确将其设置为10.
有更好的方法吗?
答案 0 :(得分:6)
如果y
是回归对象的predict
方法的输出,那么您可以将Numpy的minimum
限制为10:
y = np.minimum(y, 10.)
要将其置于零以下,请执行
y = np.maximum(np.minimum(y, 10.), 0.)
或者更短:
y = np.clip(y, 0., 10.)