sklearn SVM with RBF kernel overfitting a sine wave

Time: 2018-03-28 19:44:21

Tags: python machine-learning scikit-learn regression svm

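I'm using an SVM (scikit-learn's SVR) to learn a sine wave from noisy input. I've tried many different hyperparameters, but the model still seems to overfit. I'm not sure whether I'm simply asking too much of the model or whether I've just picked the wrong hyperparameters. Is there perhaps a better model for this task? I can post plots of the input and output if that helps. Here is my code: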

import numpy as np
from sklearn.svm import SVR
import matplotlib.pyplot as plt

def svr(x, y, C=1, gamma=.1):
    # train/test split
    split = int(y.shape[0] * 0.8)
    x_train, x_test = x[:split], x[split:]
    y_train, y_test = y[:split], y[split:]

    # fit SVM
    svrm = SVR(kernel='rbf', C=C, gamma=gamma)  
    svr_fit = svrm.fit(x_train, y_train.flatten())

    # make predictions on train and test sets
    y_fitted = svr_fit.predict(x_train)
    y_predicted = svr_fit.predict(x_test)
    print('Done!')

    return y_fitted, y_predicted

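# 5000 evenly spaced points and the clean sine wave
lin = np.linspace(0, 100, 5000)
rand = np.random.random(lin.shape)
sin = np.sin(lin)

# input: sine plus uniform noise in [-0.25, 0.25); target: the clean sine
x = (sin + rand/2 - 0.25).reshape((-1, 1))
y = sin.reshape((-1, 1))
print(x.shape, y.shape)

# plot the clean signal, then the first 500 noisy input samples
plt.plot(lin, sin)
plt.show()
plt.plot(lin[:500], x.flatten()[:500])
plt.show()

y_fit, y_pred = svr(x, y, C=1, gamma=.1)

# plot fitted values on the training range and predictions on the held-out range
plt.plot(lin[:len(y_fit)], y_fit)
plt.show()
plt.plot(lin[len(y_fit):], y_pred)
plt.show()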
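Whether this is really overfitting is easier to judge from error numbers than from plots. The following is a minimal sketch, assuming the same data generation as above plus a fixed random seed (the question does not seed its noise). It tunes C and gamma with GridSearchCV over a small grid whose values are arbitrary guesses, uses TimeSeriesSplit so each validation fold comes after its training fold, and then compares train MSE, test MSE, and a naive baseline that simply predicts the noisy input. A train MSE far below the test MSE would point to overfitting; if both sit close to the baseline, a single noisy sample may just not carry much more information about the clean value.

```python
import numpy as np
from sklearn.svm import SVR
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import GridSearchCV, TimeSeriesSplit

# Recreate the question's data; the fixed seed is an added assumption for reproducibility.
rng = np.random.default_rng(0)
lin = np.linspace(0, 100, 5000)
sin = np.sin(lin)
x = (sin + rng.random(lin.shape) / 2 - 0.25).reshape(-1, 1)  # noise ~ U[-0.25, 0.25)
y = sin  # clean target, kept 1-D as SVR.fit expects

# Same chronological 80/20 split as in the question.
split = int(y.shape[0] * 0.8)
x_train, x_test = x[:split], x[split:]
y_train, y_test = y[:split], y[split:]

# Naive baseline: predict the noisy input itself.
# Its expected MSE is the noise variance, 0.5**2 / 12, roughly 0.021.
baseline_mse = mean_squared_error(y_test, x_test.ravel())

# Search C and gamma with time-ordered CV folds (grid values are illustrative only).
search = GridSearchCV(
    SVR(kernel="rbf"),
    param_grid={"C": [0.1, 1, 10], "gamma": [0.01, 0.1, 1]},
    cv=TimeSeriesSplit(n_splits=5),
    scoring="neg_mean_squared_error",
)
search.fit(x_train, y_train)
best = search.best_estimator_

# Large gap between these two numbers would indicate overfitting.
train_mse = mean_squared_error(y_train, best.predict(x_train))
test_mse = mean_squared_error(y_test, best.predict(x_test))

print("best params:", search.best_params_)
print(f"train MSE {train_mse:.4f}  test MSE {test_mse:.4f}  baseline MSE {baseline_mse:.4f}")
```

The baseline matters here because the model only ever sees one noisy value per prediction and no neighbouring samples, so "better than just returning the input" is the bar it has to clear.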

0 answers:

No answers yet