使用scikit限制线性回归预测值

时间:2012-03-17 22:01:50

标签: python linear-regression scikits

我正在使用数据集训练线性回归模型,该数据集在区间[0,10]中具有实值标签。我在测试集上的预测值有一些超过10的预测。有没有办法将预测限制为10.

我正在考虑进行条件检查,如果预测超过10,我明确将其设置为10.

有更好的方法吗?

1 个答案:

答案 0 :(得分:6)

如果y是回归对象的predict方法的输出,那么您可以将Numpy的minimum限制为10:

y = np.minimum(y, 10.)

要将其置于零以下,请执行

y = np.maximum(np.minimum(y, 10.), 0.)

或者更短:

y = np.clip(y, 0., 10.)