我正尝试使用Keras神经网络预测值的数组,如下所示:
def create_network():
np.random.seed(0)
number_of_features = (22)
# start neural network
network = models.Sequential()
# Adding three layers
# Add fully connected layer with ReLU
network.add(layers.Dense(units=35, activation="relu", input_shape=(number_of_features,)))
# Add fully connected layer with ReLU
network.add(layers.Dense(units=35, activation='relu'))
# Add fully connected layer with no activation function
network.add(layers.Dense(units=1))
# Compile neural network
network.compile(loss='mse',
optimizer='RMSprop',
metrics=['mse'])
# Return compiled network
return network
neural_network = KerasClassifier(build_fn=create_network, epochs = 10, batch_size = 100, verbose = 0)
neural_network.fit(x_train, y_train)
neural_network.predict(x_test)
使用代码进行预测时,我得到以下输出数据:
array([[-0.23991031],
[-0.23991031],
[-0.23991031],
[-0.23991031],
[-0.23991031],
[-0.23991031],
[-0.23991031],
[-0.23991031],
[-0.23991031],
[-0.23991031],
[-0.23991031],
[-0.23991031],
[-0.23991031],
[-0.23991031],
[-0.23991031],
[-0.23991031],
[-0.23991031],
[-0.23991031],
[-0.23991031],
[-0.23991031],
[-0.23991031]])
此数组中的不同值不应相同,因为输入数据点不相同。为什么会这样?
数据被分为训练和测试,如下所示:
x_train, x_test, y_train, y_test = train_test_split(df3, df2['RETURN NEXT 12 MONTHS'], test_size=0.2)# 0.2 means 20% of the values are used for testing
以下是一些数据样本:
x_train
RETURN ON INVESTED CAPITAL ... VOLATILITY
226 0.0436 ... 0.3676
309 0.1073 ... 0.3552
306 0.1073 ... 0.3660
238 0.1257 ... 0.4352
254 0.1960 ... 0.4230
308 0.1073 ... 0.3661
327 0.2108 ... 0.2674
325 0.2108 ... 0.2836
...
上面的数据框有22列,大约100行。
相应的y训练数据为:
y_train
226 0.137662
309 1.100000
306 0.725738
238 0.244292
254 -0.557806
308 1.052402
327 -0.035730
...
我尝试使用不同数量的纪元,batch_size和不同的模型体系结构,但是它们为所有输入都提供相同的输出。
每纪元损失:
Epoch 1/10
10/100 [==>...........................] - ETA: 1s - loss: 1525.8176 - mse: 1525.8176
100/100 [==============================] - 0s 2ms/step - loss: 13771.8389 - mse: 13771.8389
Epoch 2/10
10/100 [==>...........................] - ETA: 0s - loss: 4315.0015 - mse: 4315.0015
30/100 [========>.....................] - ETA: 0s - loss: 23554.2446 - mse: 23554.2441
40/100 [===========>..................] - ETA: 0s - loss: 18089.7297 - mse: 18089.7305
50/100 [==============>...............] - ETA: 0s - loss: 15002.7878 - mse: 15002.7871
100/100 [==============================] - 0s 2ms/step - loss: 10520.1019 - mse: 10520.1025
Epoch 3/10
10/100 [==>...........................] - ETA: 0s - loss: 2722.1135 - mse: 2722.1135
100/100 [==============================] - 0s 167us/step - loss: 8500.4698 - mse: 8500.4697
Epoch 4/10
10/100 [==>...........................] - ETA: 0s - loss: 3192.2231 - mse: 3192.2231
50/100 [==============>...............] - ETA: 0s - loss: 4860.0622 - mse: 4860.0620
90/100 [==========================>...] - ETA: 0s - loss: 7377.6898 - mse: 7377.6904
100/100 [==============================] - 0s 1ms/step - loss: 6911.2499 - mse: 6911.2500
Epoch 5/10
10/100 [==>...........................] - ETA: 0s - loss: 1996.5687 - mse: 1996.5687
60/100 [=================>............] - ETA: 0s - loss: 6902.6661 - mse: 6902.6660
70/100 [====================>.........] - ETA: 0s - loss: 6162.6467 - mse: 6162.6470
90/100 [==========================>...] - ETA: 0s - loss: 6195.4129 - mse: 6195.4131
100/100 [==============================] - 0s 4ms/step - loss: 5773.2919 - mse: 5773.2925
Epoch 6/10
10/100 [==>...........................] - ETA: 0s - loss: 3063.1946 - mse: 3063.1946
80/100 [=======================>......] - ETA: 0s - loss: 5351.2784 - mse: 5351.2793
90/100 [==========================>...] - ETA: 0s - loss: 5100.9203 - mse: 5100.9214
100/100 [==============================] - 0s 2ms/step - loss: 4755.7785 - mse: 4755.7793
Epoch 7/10
10/100 [==>...........................] - ETA: 0s - loss: 3710.8032 - mse: 3710.8032
70/100 [====================>.........] - ETA: 0s - loss: 4607.9606 - mse: 4607.9609
100/100 [==============================] - 0s 943us/step - loss: 3847.0730 - mse: 3847.0732
Epoch 8/10
10/100 [==>...........................] - ETA: 0s - loss: 1742.0632 - mse: 1742.0632
30/100 [========>.....................] - ETA: 0s - loss: 2304.5816 - mse: 2304.5818
100/100 [==============================] - 0s 1ms/step - loss: 3109.5293 - mse: 3109.5293
Epoch 9/10
10/100 [==>...........................] - ETA: 0s - loss: 2027.7537 - mse: 2027.7537
100/100 [==============================] - 0s 574us/step - loss: 2537.4794 - mse: 2537.4795
Epoch 10/10
10/100 [==>...........................] - ETA: 0s - loss: 2966.5125 - mse: 2966.5125
100/100 [==============================] - 0s 177us/step - loss: 2191.3686 - mse: 2191.3687