Here is my very basic ANN code:
import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras import Sequential
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler,normalize
data = pd.read_csv("home_data.csv")
x = data.drop(['id', 'date', 'price'], axis=1).values
y = data['price'].values
x_train, x_test, y_train, y_test = train_test_split(x,y,test_size=0.33)
model = Sequential()
model.add(Dense(18, input_shape=(18,), activation="sigmoid"))
model.add(Dense(36, input_shape=(18,), activation="sigmoid"))
model.add(Dense(1, input_shape=(18,), activation="sigmoid"))
model.compile(optimizer='sgd', loss='mean_squared_error')
r = model.fit(x_train, y_train, validation_data=(x_test,y_test), epochs=50)
plt.plot(r.history['loss'], label="loss")
plt.plot(r.history['val_loss'], label="val_loss")
plt.show()
But my loss is extremely high (around 426470263086) and never decreases over time. Here is my loss plot.
Update

Here is a sample of the data I'm working with:
id date price bedrooms ... lat long sqft_living15 sqft_lot15
0 7129300520 20141013T000000 221900.0 3 ... 47.5112 -122.257 1340 5650
1 6414100192 20141209T000000 538000.0 3 ... 47.7210 -122.319 1690 7639
2 5631500400 20150225T000000 180000.0 2 ... 47.7379 -122.233 2720 8062
3 2487200875 20141209T000000 604000.0 4 ... 47.5208 -122.393 1360 5000
4 1954400510 20150218T000000 510000.0 3 ... 47.6168 -122.045 1800 7503
[5 rows x 21 columns]
Answer (score: 2)
It looks like you are trying to predict a continuous value. When predicting a continuous value, the activation in the last layer should be linear (or leaky ReLU if the predicted values are strictly positive), or there should be no activation at all; a squashing activation like sigmoid cannot work here.
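To see why the sigmoid output layer is the core problem: a sigmoid can only emit values strictly between 0 and 1, while the targets are prices in the hundreds of thousands, so the MSE can never come down. A minimal sketch with toy numbers (not the asker's data):

```python
import numpy as np

def sigmoid(z):
    # Standard logistic function: output is always strictly between 0 and 1.
    return 1.0 / (1.0 + np.exp(-z))

logits = np.array([-10.0, 0.0, 10.0])   # whatever the last layer computes
preds = sigmoid(logits)                 # squashed into (0, 1)
target = 221900.0                       # a typical price from the sample data

mse = np.mean((preds - target) ** 2)
print(preds.max() < 1.0)   # True: predictions can never reach price scale
print(mse > 1e10)          # True: MSE is stuck on the order of target**2
```

No matter how the weights move the logits, the loss floor stays near `target**2`, which matches the huge, flat loss curve in the question.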
import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras import Sequential
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler,normalize
data = pd.read_csv("home_data.csv")
x = data.drop(['id', 'price', 'date'], axis=1).values
y = data['price'].values
x_train, x_test, y_train, y_test = train_test_split(x,y,test_size=0.33)
scaler = StandardScaler()
x_train = scaler.fit_transform(x_train)
x_test = scaler.transform(x_test)
model = Sequential()
model.add(Dense(12, input_shape=(18,), activation="relu"))
model.add(Dense(6, activation="relu"))
model.add(Dense(1, activation="linear"))
model.compile(optimizer='sgd', loss='mean_squared_error', metrics = [tf.keras.metrics.RootMeanSquaredError()])
r = model.fit(x_train, y_train, validation_data=(x_test,y_test), epochs=10)
plt.plot(r.history['loss'], label="loss")
plt.plot(r.history['val_loss'], label="val_loss")
plt.show()
You don't have to specify the input shape for the hidden layers.
The model's computed loss was extremely high because of the large spread between the minimum and maximum values in the dataset.
After applying the StandardScaler, the loss decreased.
Output:
Epoch 1/10
453/453 [==============================] - 1s 3ms/step - loss: 6093344963084377128960.0000 - root_mean_squared_error: 78059880448.0000 - val_loss: 9416156905472.0000 - val_root_mean_squared_error: 3068575.7500
Epoch 2/10
453/453 [==============================] - 1s 3ms/step - loss: 639826591744.0000 - root_mean_squared_error: 799891.6250 - val_loss: 155623915520.0000 - val_root_mean_squared_error: 394491.9688
Epoch 3/10
453/453 [==============================] - 1s 2ms/step - loss: 124726026240.0000 - root_mean_squared_error: 353165.7188 - val_loss: 155318534144.0000 - val_root_mean_squared_error: 394104.7188
Epoch 4/10
453/453 [==============================] - 1s 3ms/step - loss: 124705193984.0000 - root_mean_squared_error: 353136.2188 - val_loss: 155418017792.0000 - val_root_mean_squared_error: 394230.9062
Epoch 5/10
453/453 [==============================] - 1s 3ms/step - loss: 124720766976.0000 - root_mean_squared_error: 353158.2812 - val_loss: 155389984768.0000 - val_root_mean_squared_error: 394195.3750
Epoch 6/10
453/453 [==============================] - 1s 3ms/step - loss: 124696051712.0000 - root_mean_squared_error: 353123.2812 - val_loss: 155291697152.0000 - val_root_mean_squared_error: 394070.6875
Epoch 7/10
453/453 [==============================] - 1s 3ms/step - loss: 124681125888.0000 - root_mean_squared_error: 353102.1562 - val_loss: 155307376640.0000 - val_root_mean_squared_error: 394090.5625
Epoch 8/10
453/453 [==============================] - 1s 3ms/step - loss: 124710920192.0000 - root_mean_squared_error: 353144.3438 - val_loss: 155327266816.0000 - val_root_mean_squared_error: 394115.8125
Epoch 9/10
453/453 [==============================] - 1s 3ms/step - loss: 124708052992.0000 - root_mean_squared_error: 353140.2812 - val_loss: 155288338432.0000 - val_root_mean_squared_error: 394066.4062
Epoch 10/10
453/453 [==============================] - 1s 3ms/step - loss: 124725968896.0000 - root_mean_squared_error: 353165.6250 - val_loss: 155315683328.0000 - val_root_mean_squared_error: 394101.0938
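Note that in the log above the loss plateaus around 1.2e11 after epoch 3: only the inputs were standardized, so the target is still in raw dollars and the MSE stays on the scale of price². A further step worth trying (my addition, not part of the answer above) is to scale y as well and invert the transform at prediction time. A minimal sketch with synthetic prices standing in for the real data:

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Synthetic stand-in for the price column: same order of magnitude.
rng = np.random.default_rng(0)
y_train = rng.uniform(100_000, 800_000, size=(1000, 1))

y_scaler = StandardScaler()
y_train_s = y_scaler.fit_transform(y_train)   # ~zero mean, unit variance

# Fit the model on y_train_s instead of y_train, then map predictions
# back to dollars with inverse_transform:
preds_s = y_train_s[:5]                        # pretend these are model outputs
preds = y_scaler.inverse_transform(preds_s)

print(abs(y_train_s.mean()) < 1e-8)            # True: targets are centered
print(np.allclose(preds, y_train[:5]))         # True: round-trip recovers prices
```

With both x and y standardized, the reported MSE is in scaled units (typically order 1), which makes training curves far easier to read and SGD far more stable.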