Tensorflow - Why is my ANN model not learning?

Date: 2020-08-22 18:13:24

Tags: python tensorflow deep-learning

Here is my very basic ANN code:

import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras import Sequential
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler,normalize

data = pd.read_csv("home_data.csv")
x = data.drop(['id', 'date', 'price'], axis=1).values
y = data['price'].values

x_train, x_test, y_train, y_test = train_test_split(x,y,test_size=0.33)

model = Sequential()
model.add(Dense(18, input_shape=(18,), activation="sigmoid"))
model.add(Dense(36, input_shape=(18,), activation="sigmoid"))
model.add(Dense(1, input_shape=(18,), activation="sigmoid"))
model.compile(optimizer='sgd', loss='mean_squared_error')
r = model.fit(x_train, y_train, validation_data=(x_test,y_test), epochs=50)
plt.plot(r.history['loss'], label="loss")
plt.plot(r.history['val_loss'], label="val_loss")
plt.show()

But my loss is extremely high (around 426470263086) and it never decreases over time. Here is my loss plot:

[loss plot image]

Update

Here is a sample of the data I am working with:

           id             date     price  bedrooms  ...      lat     long  sqft_living15  sqft_lot15
0  7129300520  20141013T000000  221900.0         3  ...  47.5112 -122.257           1340        5650
1  6414100192  20141209T000000  538000.0         3  ...  47.7210 -122.319           1690        7639
2  5631500400  20150225T000000  180000.0         2  ...  47.7379 -122.233           2720        8062
3  2487200875  20141209T000000  604000.0         4  ...  47.5208 -122.393           1360        5000
4  1954400510  20150218T000000  510000.0         3  ...  47.6168 -122.045           1800        7503

[5 rows x 21 columns]


1 Answer:

Answer 0 (score: 2)

It looks like you are trying to predict a continuous value. When predicting continuous values, the activation in your last layer should be linear, or a (leaky) ReLU if the predicted values are known to be positive; otherwise use no activation at all. A sigmoid output is bounded between 0 and 1, so the network can never produce a prediction anywhere near a six-figure house price, which is why the loss stays enormous no matter how long you train.
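To see why, here is a quick standalone check (my own illustration, not part of the original answer): sigmoid saturates, so no input can push its output past 1.

import numpy as np

def sigmoid(z):
    # Logistic function: output is always strictly between 0 and 1
    return 1 / (1 + np.exp(-z))

print(sigmoid(100.0))   # ~1.0, the upper bound
print(sigmoid(-100.0))  # ~0.0, the lower bound
# A target like 221900 is unreachable, so the squared error stays ~ price**2

With a linear output layer (and scaled inputs), the corrected version looks like this: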

import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras import Sequential
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler,normalize

data = pd.read_csv("home_data.csv")
x = data.drop(['id', 'price', 'date'], axis=1).values
y = data['price'].values

x_train, x_test, y_train, y_test = train_test_split(x,y,test_size=0.33)

# Standardize the features: fit on the training set only, then apply to both
scaler = StandardScaler()
x_train = scaler.fit_transform(x_train)
x_test = scaler.transform(x_test)

model = Sequential()
model.add(Dense(12, input_shape=(18,), activation="relu"))  # 18 features remain after dropping id, date, price
model.add(Dense(6, activation="relu"))
model.add(Dense(1, activation="linear"))  # linear output for regression
model.compile(optimizer='sgd', loss='mean_squared_error', metrics=[tf.keras.metrics.RootMeanSquaredError()])
r = model.fit(x_train, y_train, validation_data=(x_test,y_test), epochs=10)

plt.plot(r.history['loss'], label="loss")
plt.plot(r.history['val_loss'], label="val_loss")
plt.show()

You do not need to specify the input shape for the hidden layers; Keras infers it from the previous layer's output.
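As a quick sanity check (my addition, not from the answer), model.summary() confirms the inferred shapes:

model.summary()
# Only the first Dense layer was given input_shape=(18,); the inputs of the
# later layers, (None, 12) and (None, 6), are inferred automatically.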

The model's computed loss is very high because there is a large gap between the minimum and maximum values in the dataset.
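One quick way to confirm this (my own check, not part of the answer) is to look at the spread of the raw columns before scaling:

# lat/long sit around 47/-122, square footage runs into the thousands, and
# price is in the hundreds of thousands (see the sample rows above)
print(data.drop(['id', 'date'], axis=1).agg(['min', 'max']))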

After applying the StandardScaler, the loss decreased.
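The loss is still large, though, because the target price itself is unscaled. A further step, sketched here as my own extension rather than part of the answer, is to standardize y as well and invert the transform when predicting:

# Extension (not in the original answer): scale the target too, so the MSE
# lands near 1.0 instead of ~1e11
y_scaler = StandardScaler()
y_train_s = y_scaler.fit_transform(y_train.reshape(-1, 1))
y_test_s = y_scaler.transform(y_test.reshape(-1, 1))

model.fit(x_train, y_train_s, validation_data=(x_test, y_test_s), epochs=10)

# Convert predictions back to dollars
preds = y_scaler.inverse_transform(model.predict(x_test))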

Output:

Epoch 1/10
453/453 [==============================] - 1s 3ms/step - loss: 6093344963084377128960.0000 - root_mean_squared_error: 78059880448.0000 - val_loss: 9416156905472.0000 - val_root_mean_squared_error: 3068575.7500
Epoch 2/10
453/453 [==============================] - 1s 3ms/step - loss: 639826591744.0000 - root_mean_squared_error: 799891.6250 - val_loss: 155623915520.0000 - val_root_mean_squared_error: 394491.9688
Epoch 3/10
453/453 [==============================] - 1s 2ms/step - loss: 124726026240.0000 - root_mean_squared_error: 353165.7188 - val_loss: 155318534144.0000 - val_root_mean_squared_error: 394104.7188
Epoch 4/10
453/453 [==============================] - 1s 3ms/step - loss: 124705193984.0000 - root_mean_squared_error: 353136.2188 - val_loss: 155418017792.0000 - val_root_mean_squared_error: 394230.9062
Epoch 5/10
453/453 [==============================] - 1s 3ms/step - loss: 124720766976.0000 - root_mean_squared_error: 353158.2812 - val_loss: 155389984768.0000 - val_root_mean_squared_error: 394195.3750
Epoch 6/10
453/453 [==============================] - 1s 3ms/step - loss: 124696051712.0000 - root_mean_squared_error: 353123.2812 - val_loss: 155291697152.0000 - val_root_mean_squared_error: 394070.6875
Epoch 7/10
453/453 [==============================] - 1s 3ms/step - loss: 124681125888.0000 - root_mean_squared_error: 353102.1562 - val_loss: 155307376640.0000 - val_root_mean_squared_error: 394090.5625
Epoch 8/10
453/453 [==============================] - 1s 3ms/step - loss: 124710920192.0000 - root_mean_squared_error: 353144.3438 - val_loss: 155327266816.0000 - val_root_mean_squared_error: 394115.8125
Epoch 9/10
453/453 [==============================] - 1s 3ms/step - loss: 124708052992.0000 - root_mean_squared_error: 353140.2812 - val_loss: 155288338432.0000 - val_root_mean_squared_error: 394066.4062
Epoch 10/10
453/453 [==============================] - 1s 3ms/step - loss: 124725968896.0000 - root_mean_squared_error: 353165.6250 - val_loss: 155315683328.0000 - val_root_mean_squared_error: 394101.0938