Is it possible to add predict functionality like that of the sklearn library, and how would I do it?
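import math
import numpy

# Returns the monomial a * x^b as a function of x
def monomial(a, b):
    return lambda x: a * math.pow(x, b)

# Returns a list of monomials that make up a polynomial of the desired order
def polyList(order):
    return [monomial(1, i) for i in range(0, order + 1)]

# Returns the sum of the functions evaluated at the given input
def evaluate(functionList, x):
    return sum([f(x) for f in functionList])

# Returns the weighted sum, i.e. w0*f0 + w1*f1 + ...
def weightedSum(w, F):
    if len(w) != len(F):
        raise Exception("Function/weight size mismatch")
    else:
        return lambda x: sum([w[i] * F[i](x) for i in range(0, len(w))])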
############
# Here we fit a polynomial of the given order using the maximum likelihood estimate of the weights.
def polyTrain(x, y, order):
    # Initialize the weight vector and design matrix
    w = [1 for i in range(0, order + 1)]
    F = polyList(order)
    design = [[f(i) for f in F] for i in x]
    # Convert them to numpy arrays
    w = numpy.asarray(w)
    design = numpy.asarray(design)
    # We solve Ax = b: [n samples x (order + 1)] [coefficients]^T = [y values]
    pinv = numpy.linalg.pinv(design)
    t = numpy.asarray(y).T
    # The ML estimate of the weights is w* = pinv(design) . y^T
    w = numpy.dot(pinv, t)
    return weightedSum(w, F)
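As a sanity check on what polyTrain computes, here is a minimal sketch that builds the same kind of design matrix with np.vander(..., increasing=True) and solves for the weights with the pseudo-inverse. The sample data and the names true_w, design, w_hat, and predict are made up for illustration and are not part of the code above. The "maximum likelihood" reading assumes Gaussian noise on y, in which case it coincides with ordinary least squares.

```python
import numpy as np

# Hypothetical noise-free data generated from y = 3 + 2x + x^2.
true_w = np.array([3.0, 2.0, 1.0])  # weights for x^0, x^1, x^2
x = np.linspace(-2.0, 2.0, 9)
y = true_w[0] + true_w[1] * x + true_w[2] * x**2

order = 2
# One row per sample, columns x^0 .. x^order (same layout as polyList).
design = np.vander(x, order + 1, increasing=True)

# Least-squares weights: w* = pinv(design) @ y.
w_hat = np.linalg.pinv(design) @ y


# The fitted polynomial as a plain function, similar to what polyTrain returns.
def predict(x_new):
    row = np.vander(np.atleast_1d(x_new), order + 1, increasing=True)
    return row @ w_hat


print(w_hat)
print(predict(5.0))
```

With noise-free data the recovered weights should match true_w up to floating-point error, which makes it easy to spot an off-by-one in the number of columns.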
Answer 0 (score: 2)
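It is best to define a class that handles all of the logic. However, if you want code that fully complies with the fit-transform-predict protocol used in scikit-learn, you need to inherit from scikit-learn's base classes, such as BaseEstimator, TransformerMixin, and RegressorMixin.
NumPy provides a very handy function, vander, which helps a lot when you work with polynomials.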
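To make the column layout of np.vander concrete, here is a small sketch. The input values and the variable names V, w, values, and same are chosen only for illustration.

```python
import numpy as np

x = np.array([1, 2, 5])

# Default layout: highest power first, so the columns are x^2, x^1, x^0.
V = np.vander(x, 3)
# V is [[ 1, 1, 1],
#       [ 4, 2, 1],
#       [25, 5, 1]]

# Weights in the same highest-power-first order: 1*x^2 + 2*x + 3.
w = np.array([1, 2, 3])
values = V @ w  # evaluates the polynomial at every entry of x

# np.polyval uses the same coefficient ordering, so the results agree.
same = np.polyval(w, x)

print(values)  # [ 6 11 38]
print(same)    # [ 6 11 38]
```

Passing increasing=True to np.vander reverses the column order, which is the layout to use if the weights are stored lowest power first.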
Let's define a class.
import numpy as np

class PolyRegressor:  # I omit subclassing for now.
    def __init__(self, weights=None):
        self.weights = np.array(weights) if weights is not None else None

    @property
    def order(self):
        # Number of coefficients (polynomial degree + 1), or 0 if no weights are set
        return len(self.weights) if self.weights is not None else 0

    def evaluate(self, x):
        # np.vander columns are x^(order-1), ..., x^0, so weights[0] multiplies the highest power
        return np.dot(np.vander(x, self.order), self.weights[:, np.newaxis]).ravel()

    def fit(self, X, y=None):
        # The number of fitted coefficients comes from the current weights, so set them before fitting
        self.weights = (np.linalg.pinv(np.vander(X, self.order)) @ y[:, np.newaxis]).ravel()

    def predict(self, X):
        if self.weights is not None:
            return self.evaluate(X)
        else:
            raise Exception("Model wasn't fitted. Fit model first. ")

    def fit_predict(self, X, y=None):
        self.fit(X, y)
        return self.predict(X)
reg = PolyRegressor()
reg.weights = np.array([1,2,3]) # we implicitly define order = 2 here, e.g. 3 + 2x + 1x^2
reg.evaluate(np.array([5])) # testing
array([38])  # output
reg.fit_predict(np.random.rand(10), np.random.rand(10) * 5)
array([2.55922997, 1.81433623, 2.29153779, 1.78458414, 1.75961514, 2.59770317, 2.65122647, 1.81313616, 2.61993941, 2.63325695])
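If you do want the subclassing, a minimal sketch could look like the following. The class name SklearnPolyRegressor, the degree parameter, and the _design helper are assumptions made for illustration. It only handles a single input feature, and it skips the input validation a production estimator would normally perform.

```python
import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin


class SklearnPolyRegressor(RegressorMixin, BaseEstimator):
    """Hypothetical 1-D polynomial regressor following scikit-learn conventions."""

    def __init__(self, degree=2):
        # Hyperparameters are stored unchanged so get_params/set_params work.
        self.degree = degree

    def _design(self, X):
        # Accept shape (n_samples,) or (n_samples, 1); np.vander needs 1-D input.
        x = np.asarray(X, dtype=float).ravel()
        return np.vander(x, self.degree + 1)

    def fit(self, X, y):
        # Learned attributes get a trailing underscore by convention.
        self.coef_ = np.linalg.pinv(self._design(X)) @ np.asarray(y, dtype=float)
        return self  # returning self allows reg.fit(X, y).predict(X)

    def predict(self, X):
        if not hasattr(self, "coef_"):
            raise RuntimeError("Model is not fitted yet; call fit first.")
        return self._design(X) @ self.coef_


X = np.linspace(-1.0, 1.0, 20)
y = 1.0 * X**2 + 2.0 * X + 3.0  # the polynomial 3 + 2x + 1x^2

reg = SklearnPolyRegressor(degree=2).fit(X, y)
print(reg.coef_)           # approximately [1, 2, 3]
print(reg.predict([5.0]))  # approximately [38]
print(reg.score(X, y))     # R^2 from RegressorMixin, close to 1.0 here
```

Taking the degree as a constructor argument, rather than inferring it from preset weights, lets fit work on a fresh instance, and RegressorMixin supplies score without extra code.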
Adapt the code to your needs. Hope it helps...