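I am trying to use Pipeline and GridSearchCV with a custom regression function.
I got this working with sklearn.svm.SVR by setting up the class and then passing the parameters from a dictionary when setting up GridSearchCV.
I can't seem to do the same with my own regression function (RegFn). In particular, I want the values in the parameter dictionary to be passed to RegFn at setup.
Any pointers on how to make this work would be appreciated.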
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.base import BaseEstimator,RegressorMixin
parameters = dict(regfn__i=i, regfn__j=j, regfn__k=k)
regfn = RegFn(numInputs=noinputs, numOutputs=1, i = 100, j = 1, k =0.5)
scaler = StandardScaler()
steps = [ ('scaler', scaler), ('regfn', regfn ) ]
grid = GridSearchCV(Pipeline(steps), param_grid=parameters, cv=splits, refit=True, verbose=3, n_jobs=1)
Updated class structure:
class RBF(BaseEstimator,RegressorMixin):
    def __init__(self, i,j,k,numInputs=10,numOutputs=1):
        self.numInputs = numInputs
        self.numOutputs = numOutputs
        self.i = i
        self.j = j
        self.k = k
        .....
    def fit(self, x, y):
        ....
    def predict(self, x):
        ....
I still can't get the i, j, k values passed in from the dictionary. Am I referring to them the right way?
Answer 0 (score: 2)
If you want to use a custom estimator, it should inherit from sklearn.base.BaseEstimator and implement the __init__(), fit() and predict() methods.
If you also want to be able to pass parameters to the estimator, it needs the get_params() and set_params() methods. BaseEstimator provides both, as long as __init__() stores each argument unchanged in an attribute of the same name.
>>> import numpy as np
>>> from sklearn.base import BaseEstimator, RegressorMixin
>>> from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
>>> from sklearn.utils.multiclass import unique_labels
>>> from sklearn.metrics import euclidean_distances
>>> class TemplateClassifier(BaseEstimator, RegressorMixin):
...
...     def __init__(self, demo_param='demo'):
...         self.demo_param = demo_param
...
...     def fit(self, X, y):
...
...         # Check that X and y have correct shape
...         X, y = check_X_y(X, y)
...         # Store the classes seen during fit
...         self.classes_ = unique_labels(y)
...
...         self.X_ = X
...         self.y_ = y
...         # Return the classifier
...         return self
...
...     def predict(self, X):
...
...         # Check if fit has been called
...         check_is_fitted(self, ['X_', 'y_'])
...
...         # Input validation
...         X = check_array(X)
...
...         closest = np.argmin(euclidean_distances(X, self.X_), axis=1)
...         return self.y_[closest]
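As a rough illustration of how a "step name, double underscore, parameter name" key reaches an estimator inside a Pipeline, here is a minimal sketch. ToyRegressor and its shift parameter are hypothetical names made up for this example; they are not from the question. The mixin is listed before BaseEstimator because that is the order recent scikit-learn releases recommend.

```python
import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler


class ToyRegressor(RegressorMixin, BaseEstimator):
    """Hypothetical regressor: predicts the training mean plus `shift`."""

    def __init__(self, shift=0.0):
        # Store the argument unchanged under the same name so that the
        # get_params()/set_params() inherited from BaseEstimator can find it.
        self.shift = shift

    def fit(self, X, y):
        self.mean_ = float(np.mean(y))
        return self

    def predict(self, X):
        return np.full(len(X), self.mean_ + self.shift)


pipe = Pipeline([("scaler", StandardScaler()), ("regfn", ToyRegressor())])

# "regfn__shift" means: parameter `shift` of the step named "regfn".
pipe.set_params(regfn__shift=2.5)
params = pipe.get_params()
print(params["regfn__shift"])
```

GridSearchCV relies on the same set_params() mechanism when it tries each candidate, so a key such as regfn__i only works if i is an argument of the step's __init__().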
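PS: Note that the example above shows how to create a custom classifier, so it may contain code that is not quite appropriate for a regressor. For instance, the unique_labels(y) call and the classes_ attribute only make sense for classification targets.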
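For a regressor, a sketch along the following lines may be closer to what the question needs. KernelAverageRegressor, its gamma parameter and the synthetic data are all assumptions made up for this illustration; the question's RegFn would take their place. The point to note is that each value in the parameter grid is a list of candidates, and each key names an __init__() argument of the "regfn" step.

```python
import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.metrics.pairwise import rbf_kernel
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted


class KernelAverageRegressor(RegressorMixin, BaseEstimator):
    """Hypothetical regressor: RBF-weighted average of the training targets."""

    def __init__(self, gamma=1.0):
        # Only store the argument; any validation belongs in fit().
        self.gamma = gamma

    def fit(self, X, y):
        X, y = check_X_y(X, y)
        self.X_ = X
        self.y_ = y
        return self

    def predict(self, X):
        check_is_fitted(self, ["X_", "y_"])
        X = check_array(X)
        weights = rbf_kernel(X, self.X_, gamma=self.gamma)
        # The small constant guards against division by zero when every
        # weight underflows.
        return weights @ self.y_ / (weights.sum(axis=1) + 1e-12)


# Synthetic data, used only to make the sketch runnable.
rng = np.random.RandomState(0)
X = rng.uniform(-3, 3, size=(60, 2))
y = np.sin(X[:, 0]) + 0.1 * rng.randn(60)

steps = [("scaler", StandardScaler()), ("regfn", KernelAverageRegressor())]
# Each key is "<step name>__<parameter>", each value a list of candidates.
parameters = {"regfn__gamma": [0.1, 1.0, 10.0]}

grid = GridSearchCV(Pipeline(steps), param_grid=parameters, cv=3,
                    refit=True, n_jobs=1)
grid.fit(X, y)
print(grid.best_params_)
```

Because no scoring argument is given, the search falls back on the score() method that RegressorMixin supplies, which reports R^2.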