Fitting simple data with MLPRegressor

Date: 2016-12-09 22:31:48

Tags: python scikit-learn neural-network

I am experimenting with Python and scikit-learn, but I cannot get MLPRegressor to come anywhere close to the data. Where is this going wrong?

from sklearn.neural_network import MLPRegressor
import numpy as np
import matplotlib.pyplot as plt

x = np.arange(0.0, 1, 0.01).reshape(-1, 1)
y = np.sin(2 * np.pi * x).ravel()

reg = MLPRegressor(hidden_layer_sizes=(10,), activation='relu', solver='adam', alpha=0.001, batch_size='auto',
                   learning_rate='constant', learning_rate_init=0.01, power_t=0.5, max_iter=1000, shuffle=True,
                   random_state=None, tol=0.0001, verbose=False, warm_start=False, momentum=0.9,
                   nesterovs_momentum=True, early_stopping=False, validation_fraction=0.1, beta_1=0.9, beta_2=0.999,
                   epsilon=1e-08)

reg = reg.fit(x, y)

test_x = np.arange(0.0, 1, 0.05).reshape(-1, 1)
test_y = reg.predict(test_x)

fig = plt.figure()
ax1 = fig.add_subplot(111)
ax1.scatter(x, y, s=10, c='b', marker="s", label='real')
ax1.scatter(test_x,test_y, s=10, c='r', marker="o", label='NN Prediction')

plt.show()

The result is not very good: [image: failed fit]. Thank you.

2 answers:

Answer 0 (score: 10)

There are too few data points for this non-linear model, so the fit is sensitive to the random seed. A good seed helps, but it is not known a priori. You can also add more data points.

By iterating over various seeds, I determined that random_state=9 works quite well; there are surely other good seeds. Here is the full code with that seed:

from sklearn.neural_network import MLPRegressor
import numpy as np
import matplotlib.pyplot as plt

x = np.arange(0.0, 1, 0.01).reshape(-1, 1)
y = np.sin(2 * np.pi * x).ravel()

nn = MLPRegressor(hidden_layer_sizes=(10,), activation='relu', solver='adam', alpha=0.001, batch_size='auto',
                  learning_rate='constant', learning_rate_init=0.01, power_t=0.5, max_iter=1000, shuffle=True,
                  random_state=9, tol=0.0001, verbose=False, warm_start=False, momentum=0.9,
                  nesterovs_momentum=True, early_stopping=False, validation_fraction=0.1, beta_1=0.9, beta_2=0.999,
                  epsilon=1e-08)

n = nn.fit(x, y)

test_x = np.arange(0.0, 1, 0.05).reshape(-1, 1)
test_y = nn.predict(test_x)

fig = plt.figure()
ax1 = fig.add_subplot(111)
ax1.scatter(x, y, s=1, c='b', marker="s", label='real')
ax1.scatter(test_x, test_y, s=10, c='r', marker="o", label='NN Prediction')

plt.show()

[image: fit with random_state=9]

Here are the absolute errors of the fit for the seed integers i = 0..9, printed with

print(i, sum(abs(test_y - np.sin(2 * np.pi * test_x).ravel())))

which yields:

0 13.0874999193
1 7.2879574143
2 6.81003360188
3 5.73859777885
4 12.7245375367
5 7.43361211586
6 7.04137436733
7 7.42966661997
8 7.35516939164
9 2.87247035261

Now, we can still improve the fit, even for random_state=0, by increasing the number of target points from 100 to 1000 and the hidden layer size from 10 to 100:

from sklearn.neural_network import MLPRegressor
import numpy as np
import matplotlib.pyplot as plt

x = np.arange(0.0, 1, 0.001).reshape(-1, 1)
y = np.sin(2 * np.pi * x).ravel()

nn = MLPRegressor(hidden_layer_sizes=(100,), activation='relu', solver='adam', alpha=0.001, batch_size='auto',
                  learning_rate='constant', learning_rate_init=0.01, power_t=0.5, max_iter=1000, shuffle=True,
                  random_state=0, tol=0.0001, verbose=False, warm_start=False, momentum=0.9,
                  nesterovs_momentum=True, early_stopping=False, validation_fraction=0.1, beta_1=0.9, beta_2=0.999,
                  epsilon=1e-08)

n = nn.fit(x, y)

test_x = np.arange(0.0, 1, 0.05).reshape(-1, 1)
test_y = nn.predict(test_x)

fig = plt.figure()
ax1 = fig.add_subplot(111)
ax1.scatter(x, y, s=1, c='b', marker="s", label='real')
ax1.scatter(test_x, test_y, s=10, c='r', marker="o", label='NN Prediction')

plt.show()

which yields: [image: improved fit with 1000 points and 100 hidden units]

By the way, some of the parameters are not needed in your call to MLPRegressor(), such as momentum, nesterovs_momentum, etc. Check the documentation. Also, it helps to seed your example so that the results are reproducible ;)
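For completeness, here is a minimal sketch of how such a seed sweep might be run. The loop itself is my assumption (the answer only shows the print statement), and the remaining MLPRegressor parameters are left at their defaults, so the exact numbers may differ slightly from the table above:

from sklearn.neural_network import MLPRegressor
import numpy as np

x = np.arange(0.0, 1, 0.01).reshape(-1, 1)
y = np.sin(2 * np.pi * x).ravel()
test_x = np.arange(0.0, 1, 0.05).reshape(-1, 1)

# Fit one model per candidate seed and report its total absolute error
# against the true sine on the test grid.
for i in range(10):
    nn = MLPRegressor(hidden_layer_sizes=(10,), activation='relu', solver='adam',
                      learning_rate_init=0.01, max_iter=1000, random_state=i)
    nn.fit(x, y)
    test_y = nn.predict(test_x)
    print(i, sum(abs(test_y - np.sin(2 * np.pi * test_x).ravel())))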

Answer 1 (score: 10)

You only need to:

  • Change the solver to 'lbfgs'. The default 'adam' is an SGD-like method, which is effective for large and messy data but fairly useless for this kind of smooth, small data.
  • Use a smooth activation function such as tanh. relu is almost linear, so it is not suited to learning this simple non-linear function.

Here is the result and the complete code. Even just 3 hidden neurons can achieve very high accuracy.

from sklearn.neural_network import MLPRegressor
import numpy as np
import matplotlib.pyplot as plt

x = np.arange(0.0, 1, 0.01).reshape(-1, 1)
y = np.sin(2 * np.pi * x).ravel()

nn = MLPRegressor(hidden_layer_sizes=(3), 
                  activation='tanh', solver='lbfgs')

n = nn.fit(x, y)
test_x = np.arange(-0.1, 1.1, 0.01).reshape(-1, 1)
test_y = nn.predict(test_x)
fig = plt.figure()
ax1 = fig.add_subplot(111)
ax1.scatter(x, y, s=5, c='b', marker="o", label='real')
ax1.plot(test_x,test_y, c='r', label='NN Prediction')

plt.legend()
plt.show()
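As a quick sanity check (my addition, not part of the original answer), the same absolute-error metric used in the first answer can be applied to this lbfgs/tanh fit. Note that test_x here extends slightly beyond the training range, so the number also includes a little extrapolation error:

# Total absolute deviation from the true sine over the test grid above.
err = sum(abs(test_y - np.sin(2 * np.pi * test_x).ravel()))
print("total absolute error:", err)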