我的Keras模型不能预测负值

时间:2018-01-14 16:56:37

标签: python neural-network deep-learning keras regression

我想使用包含正负连续值的数据集来使用keras测试NN模型。 keras模型如下:

from keras.models import Sequential
from keras.layers import Dense
import numpy
#fix random seed for reproducibility
numpy.random.seed(7)

#load and read dataset
dataset = numpy.loadtxt("Phenols-toxicity.csv", delimiter=";")
# split into input (X) and output (Y) variables
X = dataset[:,2:4]
Y = dataset[:,1]
print ("Variables: \n", X)
print ("Target_outputs: \n", Y)
# create model
model = Sequential()
model.add(Dense(4, input_dim=2, activation='relu'))
#model.add(Dense(4, activation='relu'))
model.add(Dense(1, activation='relu'))
model.summary()
# Compile model
model.compile(loss='mean_squared_error', optimizer='sgd', metrics=['MSE'])
# Fit the model
model.fit(X, Y, epochs=500, batch_size=10)
#make predictions (test)
F = model.predict(X)
print ("Predicted values: \n", F)

一切似乎都很好,但是,所有负值都预测为零。程序soemwhere是否将值限制为正值? 我的目标值如下:

[ 0.085  2.468  0.07   0.68  -0.184  0.545 -0.063  0.871  0.113 -0.208
 0.688  1.638  2.03   0.078  0.573  1.036  0.015 -0.03  -0.381  0.701
 0.205  0.266  1.796  2.033  0.168  2.097  1.081 -0.384  0.377 -0.326
-0.143  1.292  0.701  0.334  1.157  1.638 -0.046  0.343  1.167  1.301
 0.277  1.131  0.471  0.617  0.707  0.185  0.604  0.017  0.381  0.804
 0.618  2.712 -0.092 -0.826  0.122  0.932  0.281  0.854  1.276  2.574
 1.125  0.73   0.796  1.145  1.569  2.664  0.034  1.398  0.393  0.612
-0.78   0.228 -1.043 -0.141  0.013  1.119  0.643 -0.242  0.757 -0.299
 0.599  0.36   1.778  0.053  1.268  1.276  0.516  1.167  1.638  0.478
 1.229  0.735  2.049 -0.064  1.201  1.41   1.295  0.798  1.854  0.16
-0.954  0.424 -0.51   1.638 -0.598  2.373  2.222 -0.358 -0.295  0.33
 0.183  0.122  1.745  0.081  2.097  0.914  0.979  0.084  0.473 -0.302
 0.879  0.366  0.172  0.45   1.307  0.886 -0.524  1.174 -0.512  0.939
 0.775 -1.053 -0.814  0.475 -1.021  1.42  -0.82   0.654  0.571 -0.076
 0.74   1.729  0.75   1.712  0.95   0.33   1.125  1.077  1.721  0.506
 0.539  0.266  1.745  1.229  0.632  1.585 -0.155  0.463  1.638  0.67
-0.155  2.053  0.379  0.181  0.253  1.356]

预测值如下:

[[ 0.        ]
 [ 2.03844833]
 [ 0.27423561]
 [ 0.59996957]
 [ 0.        ]
 [ 0.44271404]
 [ 0.        ]
 [ 0.47064281]
 [ 0.29890585]
 [ 0.        ]
 [ 0.95044041]
 [ 1.84322166]
 [ 1.93953323]
 [ 0.18019629]
 [ 0.68691438]
 [ 0.96168059]
 [ 0.13934678]
 [ 0.        ]
 [ 0.        ]
 [ 0.87886989]
 [ 0.30047321]
 [ 0.        ]
 [ 1.90942693]
 [ 1.83728123]
 [ 0.        ]
 [ 1.84627008]
 [ 1.25797462]
 [ 0.        ]
 [ 0.01434445]
 [ 0.        ]
 [ 0.        ]
 [ 1.1421392 ]
 [ 0.83652729]
 [ 0.37334418]
 [ 1.72099805]
 [ 1.73340106]
 [ 0.30456764]
 [ 0.        ]
 [ 1.37316585]
 [ 1.34221601]
 [ 0.6739701 ]
 [ 0.79646528]
 [ 0.03717542]
 [ 0.35218674]
 [ 0.09512168]
 [ 0.        ]
 [ 0.20107687]
 [ 0.        ]
 [ 0.01262379]
 [ 1.00669646]
 [ 0.96650052]
 [ 2.10064697]
 [ 0.        ]
 [ 0.        ]
 [ 0.25874525]
 [ 0.61007023]
 [ 0.68899512]
 [ 0.81215698]
 [ 0.88977867]
 [ 2.43740511]
 [ 1.00497019]
 [ 0.94933379]
 [ 0.83326894]
 [ 0.63394952]
 [ 1.27170706]
 [ 2.56578207]
 [ 0.        ]
 [ 1.29493976]
 [ 0.599581  ]
 [ 0.63211834]
 [ 0.        ]
 [ 0.31536853]
 [ 0.        ]
 [ 0.        ]
 [ 0.02201092]
 [ 0.84008563]
 [ 0.73076451]
 [ 0.        ]
 [ 0.4879511 ]
 [ 0.        ]
 [ 0.77698141]
 [ 0.66419512]
 [ 1.56657863]
 [ 0.25022489]
 [ 1.36990726]
 [ 1.50250816]
 [ 0.        ]
 [ 0.61219454]
 [ 0.87011993]
 [ 0.72275633]
 [ 1.36519527]
 [ 0.72287238]
 [ 2.3798852 ]
 [ 0.        ]
 [ 1.23592615]
 [ 1.43725252]
 [ 0.95585048]
 [ 0.63723856]
 [ 1.8765614 ]
 [ 0.31583393]
 [ 0.        ]
 [ 0.14386666]
 [ 0.        ]
 [ 1.68151355]
 [ 0.        ]
 [ 1.63394952]
 [ 1.97563386]
 [ 0.        ]
 [ 0.        ]
 [ 0.38875413]
 [ 0.18854523]
 [ 0.23547113]
 [ 1.13463831]
 [ 0.30076784]
 [ 1.61114097]
 [ 0.93304199]
 [ 1.04891086]
 [ 0.26546735]
 [ 0.62234318]
 [ 0.        ]
 [ 0.        ]
 [ 0.21855426]
 [ 0.        ]
 [ 0.        ]
 [ 0.        ]
 [ 0.        ]
 [ 0.        ]
 [ 0.39396375]
 [ 0.45845711]
 [ 0.        ]
 [ 0.        ]
 [ 0.        ]
 [ 0.        ]
 [ 0.        ]
 [ 0.        ]
 [ 0.4718284 ]
 [ 0.        ]
 [ 0.        ]
 [ 0.91218936]
 [ 0.        ]
 [ 0.82205164]
 [ 0.78155482]
 [ 0.98432505]
 [ 2.15232277]
 [ 0.97631133]
 [ 0.59527659]
 [ 0.83814716]
 [ 0.80036032]
 [ 1.17462301]
 [ 0.51232517]
 [ 0.82968521]
 [ 0.9463613 ]
 [ 1.69353771]
 [ 1.21046495]
 [ 1.36349583]
 [ 0.94378138]
 [ 0.        ]
 [ 0.98034143]
 [ 1.66670561]
 [ 0.52768588]
 [ 0.93855476]
 [ 1.26870298]
 [ 0.        ]
 [ 0.        ]
 [ 0.        ]
 [ 1.69362605]]

[[ 0.        ]
 [ 2.03844833]
 [ 0.27423561]
 [ 0.59996957]
 [ 0.        ]
 [ 0.44271404]
 [ 0.        ]
 [ 0.47064281]
 [ 0.29890585]
 [ 0.        ]
 [ 0.95044041]
 [ 1.84322166]
 [ 1.93953323]
 [ 0.18019629]
 [ 0.68691438]
 [ 0.96168059]
 [ 0.13934678]
 [ 0.        ]
 [ 0.        ]
 [ 0.87886989]
 [ 0.30047321]
 [ 0.        ]
 [ 1.90942693]
 [ 1.83728123]
 [ 0.        ]
 [ 1.84627008]
 [ 1.25797462]
 [ 0.        ]
 [ 0.01434445]
 [ 0.        ]
 [ 0.        ]
 [ 1.1421392 ]
 [ 0.83652729]
 [ 0.37334418]
 [ 1.72099805]
 [ 1.73340106]
 [ 0.30456764]
 [ 0.        ]
 [ 1.37316585]
 [ 1.34221601]
 [ 0.6739701 ]
 [ 0.79646528]
 [ 0.03717542]
 [ 0.35218674]
 [ 0.09512168]
 [ 0.        ]
 [ 0.20107687]
 [ 0.        ]
 [ 0.01262379]
 [ 1.00669646]
 [ 0.96650052]
 [ 2.10064697]
 [ 0.        ]
 [ 0.        ]
 [ 0.25874525]
 [ 0.61007023]
 [ 0.68899512]
 [ 0.81215698]
 [ 0.88977867]
 [ 2.43740511]
 [ 1.00497019]
 [ 0.94933379]
 [ 0.83326894]
 [ 0.63394952]
 [ 1.27170706]
 [ 2.56578207]
 [ 0.        ]
 [ 1.29493976]
 [ 0.599581  ]
 [ 0.63211834]
 [ 0.        ]
 [ 0.31536853]
 [ 0.        ]
 [ 0.        ]
 [ 0.02201092]
 [ 0.84008563]
 [ 0.73076451]
 [ 0.        ]
 [ 0.4879511 ]
 [ 0.        ]
 [ 0.77698141]
 [ 0.66419512]
 [ 1.56657863]
 [ 0.25022489]
 [ 1.36990726]
 [ 1.50250816]
 [ 0.        ]
 [ 0.61219454]
 [ 0.87011993]
 [ 0.72275633]
 [ 1.36519527]
 [ 0.72287238]
 [ 2.3798852 ]
 [ 0.        ]
 [ 1.23592615]
 [ 1.43725252]
 [ 0.95585048]
 [ 0.63723856]
 [ 1.8765614 ]
 [ 0.31583393]
 [ 0.        ]
 [ 0.14386666]
 [ 0.        ]
 [ 1.68151355]
 [ 0.        ]
 [ 1.63394952]
 [ 1.97563386]
 [ 0.        ]
 [ 0.        ]
 [ 0.38875413]
 [ 0.18854523]
 [ 0.23547113]
 [ 1.13463831]
 [ 0.30076784]
 [ 1.61114097]
 [ 0.93304199]
 [ 1.04891086]
 [ 0.26546735]
 [ 0.62234318]
 [ 0.        ]
 [ 0.        ]
 [ 0.21855426]
 [ 0.        ]
 [ 0.        ]
 [ 0.        ]
 [ 0.        ]
 [ 0.        ]
 [ 0.39396375]
 [ 0.45845711]
 [ 0.        ]
 [ 0.        ]
 [ 0.        ]
 [ 0.        ]
 [ 0.        ]
 [ 0.        ]
 [ 0.4718284 ]
 [ 0.        ]
 [ 0.        ]
 [ 0.91218936]
 [ 0.        ]
 [ 0.82205164]
 [ 0.78155482]
 [ 0.98432505]
 [ 2.15232277]
 [ 0.97631133]
 [ 0.59527659]
 [ 0.83814716]
 [ 0.80036032]
 [ 1.17462301]
 [ 0.51232517]
 [ 0.82968521]
 [ 0.9463613 ]
 [ 1.69353771]
 [ 1.21046495]
 [ 1.36349583]
 [ 0.94378138]
 [ 0.        ]
 [ 0.98034143]
 [ 1.66670561]
 [ 0.52768588]
 [ 0.93855476]
 [ 1.26870298]
 [ 0.        ]
 [ 0.        ]
 [ 0.        ]
 [ 1.69362605]]

1 个答案:

答案 0 :(得分:6)

是的,您正在将负面约束为零。输出激活是一个ReLU,就是这样。

解决方案只是将输出激活更改为产生负数的输出,如tanh。请注意,该激活的范围是[-1,1],因此您必须将输出标签规范化为相同的范围。