我想重建一个MLP,我先用scikit-learn的MLPRegressor和tflearn实现。
sklearn.neural_network.MLPRegressor实施:
train_data = pd.read_csv('train_data.csv', delimiter = ';', decimal = ',', header = 0)
test_data = pd.read_csv('test_data.csv', delimiter = ';', decimal = ',', header = 0)
X_train = np.array(train_data.drop(['output'], 1))
X_scaler = StandardScaler()
X_scaler.fit(X_train)
X_train = X_scaler.transform(X_train)
Y_train = np.array(train_data['output'])
clf = MLPRegressor(activation = 'tanh', solver='lbfgs', alpha=0.0001, hidden_layer_sizes=(3))
clf.fit(X_train, Y_train)
prediction = clf.predict(X_train)
该模型有效,我的准确度为0.85
。现在我想用tflearn构建一个类似的MLP。我从以下代码开始:
train_data = pd.read_csv('train_data.csv', delimiter = ';', decimal = ',', header = 0)
test_data = pd.read_csv('test_data.csv', delimiter = ';', decimal = ',', header = 0)
X_train = np.array(train_data.drop(['output'], 1))
X_scaler = StandardScaler()
X_scaler.fit(X_train)
X_train = X_scaler.transform(X_train)
Y_train = np.array(train_data['output'])
Y_scaler = StandardScaler()
Y_scaler.fit(Y_train)
Y_train = Y_scaler.transform(Y_train.reshape((-1,1)))
net = tfl.input_data(shape=[None, 6])
net = tfl.fully_connected(net, 3, activation='tanh')
net = tfl.fully_connected(net, 1, activation='sigmoid')
net = tfl.regression(net, optimizer='sgd', loss='mean_square', learning_rate=3.)
clf = tfl.DNN(net)
clf.fit(X_train, Y_train, n_epoch=200, show_metric=True)
prediction = clf.predict(X_train)
在某些时候,我确实以错误的方式配置了一些东西,因为预测是偏离的。 Y_train的范围介于20
和88
之间,预测显示0.005
周围的数字。在tflearn文档中,我刚刚找到了分类的例子。
我意识到回归层默认使用'categorical_crossentropy'
作为损失函数,用于分类问题。所以我选择了'mean_square'
。我还尝试将Y_train
标准化。预测仍然不匹配Y_train
的范围。有什么想法吗?
看一下接受的答案。
答案 0 :(得分:0)
一步应该是不缩放输出。 我也在研究回归问题,我只对输入进行扩展,并且它可以与一些神经网络一起工作。虽然如果我使用tflearn我会得到错误的预测。
答案 1 :(得分:0)
我做了几个真正愚蠢的错误。
首先,我将输出调整到0
到1
的间隔,但在输出层使用了激活函数tanh
,它将-1
的值传递给{ {1}}。所以我不得不使用激活函数来输出1
和0
之间的值(例如1
)或sigmoid
,而不应用任何缩放。
其次,最重要的是,对于我的数据,我为linear
和learning rate
选择了一个相当糟糕的组合。我认为,我没有指定任何学习率,默认值为n_epoch
。无论如何它太小了(我最终使用0.1
)。与此同时,纪元数量(3.0
)也太小了,10
它的工作正常。
我还明确选择了200
作为sgd
(默认:optimizer
),结果证明效果更好。
我在我的问题中更新了代码。