如何避免重新训练机器学习模型

时间:2015-01-17 23:34:35

标签: python machine-learning scikit-learn

这里的自学者。

我正在构建一个预测事件的Web应用程序。

让我们考虑一下这个简单的例子。

X = [[0], [1], [2], [3]]
y = [0, 0, 1, 1]
from sklearn.neighbors import KNeighborsClassifier
neigh = KNeighborsClassifier(n_neighbors=3)
neigh.fit(X, y) 

print(neigh.predict([[1.1]]))

当我输入neigh之类的新值时,如何保持neigh.predict([[1.2]])的状态,我不需要重新训练模型。有没有好的做法,或暗示开始解决问题?

2 个答案:

答案 0 :(得分:7)

由于几个原因,您选择了一个有点混乱的例子。首先,当您说neigh.predict([[1.2]])时,您没有添加新的训练点,您只是在进行新的预测,因此根本不需要任何更改。其次,KNN算法并没有真正受过训练,并且#34; - KNN是一种instance-based算法,这意味着"培训"相当于将训练数据存储在合适的结构中。因此,这个问题有两个不同的答案。我会先尝试回答KNN问题。

K最近邻居

对于KNN,添加新的训练数据相当于将新数据点附加到结构中。但是,似乎scikit-learn没有提供任何此类功能。 (这是合理的 - 因为KNN明确存储每个训练点,你不能无限期地继续给它新的训练点。)

如果您没有使用许多培训点,那么简单的列表可能足以满足您的需求!在这种情况下,您可以完全跳过sklearn,只需将新数据点附加到列表中即可。要进行预测,请进行线性搜索,保存k最近邻居,然后根据简单的"多数投票进行预测。 - 如果五个邻居中有三个或更多是红色,则返回红色,依此类推。但请记住,您添加的每个训练点都会降低算法速度。

如果您需要使用许多训练点,您将希望使用更有效的结构进行最近邻搜索,例如K-D Tree。有scipy K-D Tree实现应该有效。 query方法允许您查找k最近邻居。它比列表更有效,但随着您添加更多训练数据,它仍然会变慢。

在线学习

对你的问题更一般的回答是,你(自己并不知道)试图做一些名为online learning的事情。在线学习算法允​​许您在到达时使用各个训练点,并在使用后将其丢弃。为了理所当然,你需要不是存储训练点本身(如在KNN中),而是存储一组优化的参数。

这意味着某些算法比其他算法更适合这种算法。 sklearn只提供了一些算法capable of online learning。这些都有一个partial_fit方法,允许您批量传递训练数据。导致'hinge''log'丢失的SKDClassifier可能是一个很好的起点。

答案 1 :(得分:5)

或者您可能只想在安装后保存模型

joblib.dump(neigh, FName)

并在需要时加载

neigh = joblib.load(FName)
neigh.predict([[1.1]])